MyraMath
Kernel.h
Go to the documentation of this file.
1 // ========================================================================= //
2 // This file is part of MyraMath, copyright (c) 2014-2019 by Ryan A Chilton //
3 // and distributed by MyraCore, LLC. See LICENSE.txt for license terms. //
4 // ========================================================================= //
5 
6 #ifndef MYRAMATH_MULTIFRONTAL_ZLDLH_KERNEL_H
7 #define MYRAMATH_MULTIFRONTAL_ZLDLH_KERNEL_H
8 
16 
17 #include <myramath/io/Streams.h>
19 #include <myramath/io/detail/vector.h>
20 #include <myramath/io/detail/complex.h>
21 
23 #include <myramath/dense/hetrf.h>
24 #include <myramath/dense/trsm.h>
25 #include <myramath/dense/swaps.h>
26 #include <myramath/dense/detail/nwork.h>
27 
28 #include <stdint.h>
29 
30 namespace myra {
31 
32 // Forward declarations.
33 class InputStream;
34 class OutputStream;
35 template<class Number> class MatrixRange;
36 template<class Number> class LowerMatrix;
37 
39 template<class Precision> class MYRAMATH_EXPORT ZLDLHKernel
40  {
41  public:
42  typedef std::complex<Precision> Number;
43 
45  explicit ZLDLHKernel()
46  { }
47 
50  { }
51 
54  { in >> P_swaps >> Q_swaps >> R >> n_plus >> n_minus; }
55 
57  void write(OutputStream& out) const
58  { out << P_swaps << Q_swaps << R << n_plus << n_minus; }
59 
61  uint64_t factor()
62  {
63  // LDLSwaps<Number> result = hetrf_inplace(L);
64  LDLSwaps<Number> result = hetrf_outplace(L);
65  P_swaps = result.P_swaps;
66  Q_swaps = result.Q_swaps;
67  R = result.R;
68  n_plus = result.n_plus;
69  n_minus = result.n_minus;
70  // Return flop count.
71  uint64_t n_work = L.size();
72  return n_work*(n_work+1)*(n_work+2)/6;
73  }
74 
76  // side = Solve by L from the 'L'eft or from the 'R'ight?
77  // op = Apply an operation to L? ('T'ranspose, 'H'ermitian, 'C'onjugate or 'N'othing)
78  uint64_t solveL(const MatrixRange<Number>& B, char side, char op) const
79  {
80  int N = this->size();
81  // Internally L is decomposed into P'*L*Q', so solving by it always takes a few steps.
82  if (side == 'L')
83  {
84  // Check size.
85  if (B.I != N) throw eprintf("ZLDLHKernel::solveL('L'eft), size mismatch B.I != N [%d != %d]", B.I, N);
86  // Solve L*X = B?
87  if (op == 'N')
88  {
89  swap_rows(P_swaps,B);
90  uint64_t w = trsm_nwork('L','N',L,B);
91  R.solve(B,'L','N');
92  swap_rows(Q_swaps,B);
93  return w;
94  }
95  // Solve transpose(L)*X = B?
96  else if (op == 'T')
97  {
98  iswap_rows(Q_swaps,B);
99  R.solve(B,'L','T');
100  uint64_t w = trsm_nwork('L','T',L,B);
101  iswap_rows(P_swaps,B);
102  return w;
103  }
104  // Solve hermitian(L)*X = B?
105  else if (op == 'H')
106  {
107  iswap_rows(Q_swaps,B);
108  R.solve(B,'L','H');
109  uint64_t w = trsm_nwork('L','H',L,B);
110  iswap_rows(P_swaps,B);
111  return w;
112  }
113  // Solve conjugate(L)*X = B?
114  else if (op == 'C')
115  {
116  swap_rows(P_swaps,B);
117  uint64_t w = trsm_nwork('L','C',L,B);
118  R.solve(B,'L','C');
119  swap_rows(Q_swaps,B);
120  return w;
121  }
122  else throw eprintf("ZLDLHKernel::solveL('L'eft), didn't understand op = %c", op);
123  }
124  else if (side == 'R')
125  {
126  // Check size.
127  if (B.J != N) throw eprintf("ZLDLHKernel::solveL('R'ight), size mismatch in B.J != N [%d != %d]", B.J, N);
128  // Solve X*L = B?
129  if (op == 'N')
130  {
131  swap_columns(Q_swaps,B);
132  R.solve(B,'R','N');
133  uint64_t w = trsm_nwork('R','N',L,B);
134  swap_columns(P_swaps,B);
135  return w;
136  }
137  // Solve X*transpose(L) = B?
138  else if (op == 'T')
139  {
140  iswap_columns(P_swaps,B);
141  uint64_t w = trsm_nwork('R','T',L,B);
142  R.solve(B,'R','T');
143  iswap_columns(Q_swaps,B);
144  return w;
145  }
146  // Solve X*hermitian(L) = B?
147  else if (op == 'H')
148  {
149  iswap_columns(P_swaps,B);
150  uint64_t w = trsm_nwork('R','H',L,B);
151  R.solve(B,'R','H');
152  iswap_columns(Q_swaps,B);
153  return w;
154  }
155  // Solve X*conjugate(L) = B?
156  else if (op == 'C')
157  {
158  swap_columns(Q_swaps,B);
159  R.solve(B,'R','C');
160  uint64_t w = trsm_nwork('R','C',L,B);
161  swap_columns(P_swaps,B);
162  return w;
163  }
164  else throw eprintf("ZLDLHKernel::solveL('R'ight), didn't understand op = %c", op);
165  }
166  else throw eprintf("ZLDLHKernel::solveL(), didn't understand side = %c", side);
167  }
168 
170  void solveI(const MatrixRange<Number>& B, char side) const
171  {
172  if (side == 'L')
173  B.bottom(n_minus) *= -Number(1);
174  else if (side == 'R')
175  B.right(n_minus) *= -Number(1);
176  else throw eprintf("ZLDLHKernel::solveI(), didn't understand side = %c", side);
177  }
178 
180  std::pair<int,int> inertia() const
181  { return std::pair<int,int>(n_plus, n_minus); }
182 
184  int size() const
185  { return L.size(); }
186 
187  private:
188 
189  // Points to underlying data.
191 
192  // For applying permutation P.
193  std::vector<int> P_swaps;
194 
195  // For applying permutation Q.
196  std::vector<int> Q_swaps;
197 
198  // For applying pivot rotations R.
200 
201  // Encodes inertia.
202  int n_plus;
203  int n_minus;
204 
205  };
206 
208 template<class Precision> class ReflectNumber< ZLDLHKernel<Precision> >
209  { public: typedef std::complex<Precision> type; };
210 
211 } // namespace myra
212 
213 #endif
Reflects the Number trait of a Container; containers of Numbers (Matrices, Vectors, etc.) should specialize...
Definition: Number.h:55
std::pair< int, int > inertia() const
Returns inertia I, (n_plus, n_minus). Useful for schur downdates.
Definition: Kernel.h:180
Returns a std::runtime_error() whose message has been populated using printf()-style formatting...
int J
---------— Data members, all public ----------------—
Definition: MatrixRange.h:43
Represents a mutable LowerMatrixRange.
Definition: conjugate.h:28
int I
---------— Data members, all public ----------------—
Definition: MatrixRange.h:42
A mostly-identity matrix type, with the occasional Matrix22 at a specific diagonal offset (n...
Definition: PivotMatrix.h:29
LDL' decompositions for complex hermitian Matrix A (indefinite is fine).
Factors A into L*I*L', presents solve methods.
Definition: Kernel.h:39
ReaderWriter<T>, encapsulates a read()/write() pair for type T.
Range construct for a lower triangular matrix stored in rectangular packed format.
Definition: syntax.dox:1
ZLDLHKernel(LowerMatrixRange< Number > &A)
Seats reference to L, to be factor()'d later.
Definition: Kernel.h:49
ZLDLHKernel()
Default constructor, initializes to 0 size.
Definition: Kernel.h:45
Routines for backsolving by a triangular Matrix or LowerMatrix.
Abstraction layer, serializable objects write themselves to these.
Definition: Streams.h:39
Various utility functions/classes related to scalar Number types.
Routines related to swap sequences, often used during pivoting.
Represents a mutable MatrixRange.
Definition: conjugate.h:26
Abstraction layer, deserializable objects read themselves from these.
Definition: Streams.h:47
uint64_t factor()
Factors A = L*I*L'.
Definition: Kernel.h:61
MatrixRange< Number > bottom(int i) const
Returns the i bottommost rows, this(I-i:I,:)
Definition: MatrixRange.cpp:186
MatrixRange< Number > right(int j) const
Returns the j rightmost columns, this(:,J-j:J)
Definition: MatrixRange.cpp:235
ZLDLHKernel(LowerMatrixRange< Number > &A, InputStream &in)
Constructs from an InputStream, after seating reference to L.
Definition: Kernel.h:53
int size() const
Size inspector.
Definition: Kernel.h:184
void write(OutputStream &out) const
Writes to an OutputStream.
Definition: Kernel.h:57
void solveI(const MatrixRange< Number > &B, char side) const
Solves I*X=B or X*I=B, overwrites B with X.
Definition: Kernel.h:170
Bases classes for binary input/output streams.
Return type of sytrf_inplace() and hetrf_inplace(), holds pivoting metadata.
Definition: LDLSwaps.h:21
uint64_t solveL(const MatrixRange< Number > &B, char side, char op) const
Solves op(L)*X=B or X*op(L)=B, overwrites B with X.
Definition: Kernel.h:78