MyraMath
Kernel.h
Go to the documentation of this file.
1 // ========================================================================= //
2 // This file is part of MyraMath, copyright (c) 2014-2019 by Ryan A Chilton //
3 // and distributed by MyraCore, LLC. See LICENSE.txt for license terms. //
4 // ========================================================================= //
5 
6 #ifndef MYRAMATH_MULTIFRONTAL_ZLDLT_KERNEL_H
7 #define MYRAMATH_MULTIFRONTAL_ZLDLT_KERNEL_H
8 
16 
17 #include <myramath/io/Streams.h>
19 #include <myramath/io/detail/vector.h>
20 #include <myramath/io/detail/complex.h>
21 
24 #include <myramath/dense/hetrf.h>
25 #include <myramath/dense/trsm.h>
26 #include <myramath/dense/swaps.h>
27 #include <myramath/dense/detail/nwork.h>
28 
29 #include <stdint.h>
30 
31 namespace myra {
32 
33 // Forward declarations.
34 class InputStream;
35 class OutputStream;
36 template<class Number> class MatrixRange;
37 template<class Number> class LowerMatrix;
38 
40 template<class Precision> class MYRAMATH_EXPORT ZLDLTKernel
41  {
42  public:
43  typedef std::complex<Precision> Number;
44 
// Default constructor, initializes to 0 size.
// The inertia counters are zeroed explicitly: they are plain ints, so an empty
// constructor body would leave them indeterminate and inertia() would read
// garbage on a kernel that was never factor()'d or read from a stream.
46  explicit ZLDLTKernel()
47  : n_plus(0), n_minus(0) { }
48 
51  { }
52 
55  { in >> P_swaps >> Q_swaps >> R >> n_plus >> n_minus; }
56 
// Writes to an OutputStream.
// Emits P_swaps, Q_swaps, R, n_plus and n_minus, in that order, which is the same
// order the stream-reading code of this class extracts them with operator>>.
// L itself is not written by this method.
58  void write(OutputStream& out) const
59  { out << P_swaps << Q_swaps << R << n_plus << n_minus; }
60 
// Factors A = L*I*L'.
// Calls sytrf_outplace(L) and stores the pivoting metadata it returns
// (P_swaps, Q_swaps, R) together with the inertia counts (n_plus, n_minus).
// Returns a flop count, n*(n+1)*(n+2)/6 with n = L.size().
62  uint64_t factor()
63  {
    // The in-place variant is kept below (commented out) as an alternative.
64  // LDLSwaps<Number> result = sytrf_inplace(L);
65  LDLSwaps<Number> result = sytrf_outplace(L);
66  P_swaps = result.P_swaps;
67  Q_swaps = result.Q_swaps;
68  R = result.R;
69  n_plus = result.n_plus;
70  n_minus = result.n_minus;
71  // Return flop count.
    // The size is held in a uint64_t so the cubic product below is evaluated in 64 bits.
72  uint64_t n_work = L.size();
73  return n_work*(n_work+1)*(n_work+2)/6;
74  }
75 
// Solves op(L)*X=B or X*op(L)=B, overwrites B with X.
// Returns the value produced by trsm_nwork() for the chosen side/op.
// Throws the exception built by eprintf() when B's extent does not match size()
// (B.I for side 'L', B.J for side 'R'), or when side/op is not recognized.
//
// Every branch visits the same three ingredients (P_swaps, the triangular factor L,
// the pivot rotations R followed by Q_swaps); the 'T'/'H' branches visit them in the
// reverse order of the 'N'/'C' branches and use the iswap_* routines instead of swap_*.
// Side 'R' mirrors side 'L' with columns in place of rows and the order flipped.
//
// NOTE(review): the only call that involves L in each branch is trsm_nwork(); whether
// that routine performs the triangular solve or only reports a work count is not
// visible from this file - confirm against myramath/dense/detail/nwork.h and trsm.h.
77  // side = Solve by L from the 'L'eft or from the 'R'ight?
78  // op = Apply an operation to L? ('T'ranspose, 'H'ermitian, 'C'onjugate or 'N'othing)
79  uint64_t solveL(const MatrixRange<Number>& B, char side, char op) const
80  {
81  int N = this->size();
82  // Internally L is decomposed into P'*L*Q', so solving by it always takes a few steps.
83  if (side == 'L')
84  {
85  // Check size.
86  if (B.I != N) throw eprintf("ZLDLTKernel::solveL('L'eft), size mismatch B.I != N [%d != %d]", B.I, N);
87  // Solve L*X = B?
88  if (op == 'N')
89  {
    // Order: P swaps, triangular factor, pivot rotations, Q swaps.
90  swap_rows(P_swaps,B);
91  uint64_t w = trsm_nwork('L','N',L,B);
92  R.solve(B,'L','N');
93  swap_rows(Q_swaps,B);
94  return w;
95  }
96  // Solve transpose(L)*X = B?
97  else if (op == 'T')
98  {
    // Reverse of the 'N' order; iswap_rows presumably undoes a swap sequence - confirm in swaps.h.
99  iswap_rows(Q_swaps,B);
100  R.solve(B,'L','T');
101  uint64_t w = trsm_nwork('L','T',L,B);
102  iswap_rows(P_swaps,B);
103  return w;
104  }
105  // Solve hermitian(L)*X = B?
106  else if (op == 'H')
107  {
     // Same step order as 'T', with op 'H' passed through to R.solve() and trsm_nwork().
108  iswap_rows(Q_swaps,B);
109  R.solve(B,'L','H');
110  uint64_t w = trsm_nwork('L','H',L,B);
111  iswap_rows(P_swaps,B);
112  return w;
113  }
114  // Solve conjugate(L)*X = B?
115  else if (op == 'C')
116  {
     // Same step order as 'N', with op 'C' passed through to trsm_nwork() and R.solve().
117  swap_rows(P_swaps,B);
118  uint64_t w = trsm_nwork('L','C',L,B);
119  R.solve(B,'L','C');
120  swap_rows(Q_swaps,B);
121  return w;
122  }
123  else throw eprintf("ZLDLTKernel::solveL('L'eft), didn't understand op = %c", op);
124  }
125  else if (side == 'R')
126  {
127  // Check size.
128  if (B.J != N) throw eprintf("ZLDLTKernel::solveL('R'ight), size mismatch B.J != N [%d != %d]", B.J, N);
129  // Solve X*L = B?
130  if (op == 'N')
131  {
     // Right-hand solves act on columns, and start from Q instead of P.
132  swap_columns(Q_swaps,B);
133  R.solve(B,'R','N');
134  uint64_t w = trsm_nwork('R','N',L,B);
135  swap_columns(P_swaps,B);
136  return w;
137  }
138  // Solve X*transpose(L) = B?
139  else if (op == 'T')
140  {
     // Reverse of the right-hand 'N' order, using iswap_columns.
141  iswap_columns(P_swaps,B);
142  uint64_t w = trsm_nwork('R','T',L,B);
143  R.solve(B,'R','T');
144  iswap_columns(Q_swaps,B);
145  return w;
146  }
147  // Solve X*hermitian(L) = B?
148  else if (op == 'H')
149  {
150  iswap_columns(P_swaps,B);
151  uint64_t w = trsm_nwork('R','H',L,B);
152  R.solve(B,'R','H');
153  iswap_columns(Q_swaps,B);
154  return w;
155  }
156  // Solve X*conjugate(L) = B?
157  else if (op == 'C')
158  {
159  swap_columns(Q_swaps,B);
160  R.solve(B,'R','C');
161  uint64_t w = trsm_nwork('R','C',L,B);
162  swap_columns(P_swaps,B);
163  return w;
164  }
165  else throw eprintf("ZLDLTKernel::solveL('R'ight), didn't understand op = %c", op);
166  }
167  else throw eprintf("ZLDLTKernel::solveL(), didn't understand side = %c", side);
168  }
169 
// Solves I*X=B or X*I=B, overwrites B with X.
// side = 'L' negates the n_minus bottommost rows of B,
// side = 'R' negates the n_minus rightmost columns of B.
// Any other side throws the exception built by eprintf().
171  void solveI(const MatrixRange<Number>& B, char side) const
172  {
173  if (side == 'L')
174  B.bottom(n_minus) *= -Number(1);
175  else if (side == 'R')
176  B.right(n_minus) *= -Number(1);
177  else throw eprintf("ZLDLTKernel::solveI(), didn't understand side = %c", side);
178  }
179 
// Returns inertia I as the pair (n_plus, n_minus). Useful for schur downdates.
181  std::pair<int,int> inertia() const
182  { return std::make_pair(n_plus, n_minus); }
183 
// Size inspector, forwards to L.size().
185  int size() const
186  { return L.size(); }
187 
188  private:
189 
190  // Points to underlying data.
192 
193  // For applying permutation P.
194  std::vector<int> P_swaps;
195 
196  // For applying permutation Q.
197  std::vector<int> Q_swaps;
198 
199  // For applying pivot rotations R.
201 
202  // Encodes inertia.
203  int n_plus;
204  int n_minus;
205 
206  };
207 
// Reflects std::complex<Precision> as the Number type of ZLDLTKernel<Precision>.
209 template<class Precision> class ReflectNumber< ZLDLTKernel<Precision> >
210  { public: typedef std::complex<Precision> type; };
211 
212 } // namespace myra
213 
214 #endif
Reflects Number trait for a Container, containers of Numbers (Matrix's, Vector's, etc) should special...
Definition: Number.h:55
Returns a std::runtime_error() whose message has been populated using printf()-style formatting...
Interface class for representing subranges of dense Matrix's.
int J
---------- Data members, all public -----------------
Definition: MatrixRange.h:43
void write(OutputStream &out) const
Writes to an OutputStream.
Definition: Kernel.h:58
Represents a mutable LowerMatrixRange.
Definition: conjugate.h:28
int I
---------- Data members, all public -----------------
Definition: MatrixRange.h:42
std::pair< int, int > inertia() const
Returns inertia I, (n_plus, n_minus). Useful for schur downdates.
Definition: Kernel.h:181
A mostly-identity matrix type, with the occasional Matrix22 at a specific diagonal offset (n...
Definition: PivotMatrix.h:29
LDL' decompositions for real hermitian Matrix A (indefinite is fine).
ReaderWriter<T>, encapsulates a read()/write() pair for type T.
Range construct for a lower triangular matrix stored in rectangular packed format.
Definition: syntax.dox:1
Routines for backsolving by a triangular Matrix or LowerMatrix.
Abstraction layer, serializable objects write themselves to these.
Definition: Streams.h:39
Various utility functions/classes related to scalar Number types.
Routines related to swap sequences, often used during pivoting.
Represents a mutable MatrixRange.
Definition: conjugate.h:26
Abstraction layer, deserializable objects read themselves from these.
Definition: Streams.h:47
void solveI(const MatrixRange< Number > &B, char side) const
Solves I*X=B or X*I=B, overwrites B with X.
Definition: Kernel.h:171
uint64_t solveL(const MatrixRange< Number > &B, char side, char op) const
Solves op(L)*X=B or X*op(L)=B, overwrites B with X.
Definition: Kernel.h:79
ZLDLTKernel()
Default constructor, initializes to 0 size.
Definition: Kernel.h:46
Factors A into L*I*L', presents solve methods.
Definition: Kernel.h:40
ZLDLTKernel(LowerMatrixRange< Number > &A)
Seats reference to L, to be factor()'d later.
Definition: Kernel.h:50
uint64_t factor()
Factors A = L*I*L'.
Definition: Kernel.h:62
MatrixRange< Number > bottom(int i) const
Returns the i bottommost rows, this(I-i:I,:)
Definition: MatrixRange.cpp:186
MatrixRange< Number > right(int j) const
Returns the j rightmost columns, this(:,J-j:J)
Definition: MatrixRange.cpp:235
ZLDLTKernel(LowerMatrixRange< Number > &A, InputStream &in)
Constructs from an InputStream, after seating reference to L.
Definition: Kernel.h:54
int size() const
Size inspector.
Definition: Kernel.h:185
Base classes for binary input/output streams.
Return type of sytrf_inplace() and hetrf_inplace(), holds pivoting metadata.
Definition: LDLSwaps.h:21