MyraMath
Kernel.h
Go to the documentation of this file.
1 // ========================================================================= //
2 // This file is part of MyraMath, copyright (c) 2014-2019 by Ryan A Chilton //
3 // and distributed by MyraCore, LLC. See LICENSE.txt for license terms. //
4 // ========================================================================= //
5 
6 #ifndef MYRAMATH_MULTIFRONTAL_LU_KERNEL_H
7 #define MYRAMATH_MULTIFRONTAL_LU_KERNEL_H
8 
16 
17 #include <myramath/io/Streams.h>
19 
21 #include <myramath/dense/getrf.h>
22 #include <myramath/dense/trsm.h>
23 #include <myramath/dense/swaps.h>
24 #include <myramath/dense/detail/nwork.h>
25 
26 #include <stdint.h>
27 
28 namespace myra {
29 
30 // Forward declarations.
31 class InputStream;
32 class OutputStream;
33 
35 template<class Number> class MYRAMATH_EXPORT LUKernel
36  {
37  public:
38 
40  explicit LUKernel()
41  { }
42 
44  explicit LUKernel(MatrixRange<Number>& A) : LU(A)
45  { }
46 
48  explicit LUKernel(MatrixRange<Number>& A, InputStream& in) : LU(A)
49  { in >> swaps; }
50 
52  void write(OutputStream& out) const
53  { out << swaps; }
54 
56  uint64_t factor()
57  {
58  swaps = getrf_inplace(LU);
59  uint64_t k_work = this->size();
60  return k_work*(k_work+1)*(2*k_work+1)/6;
61  }
62 
64  // side = Solve by L from the 'L'eft or from the 'R'ight?
65  // op = Apply an operation to L? ('T'ranspose, 'H'ermitian, 'C'onjugate or 'N'othing)
66  uint64_t solveL(const MatrixRange<Number>& B, char side, char op) const
67  {
68  int N = this->size();
69  if (side == 'L')
70  {
71  // Check size.
72  if (B.I != N) throw eprintf("LUKernel::solveL('L'eft), size mismatch B.I != N [%d != %d]", B.I, N);
73  if (op == 'N')
74  {
75  swap_rows(swaps,B);
76  uint64_t w = trsm_nwork('L','L','N',LU,B,'U');
77  return w;
78  }
79  else if (op == 'T')
80  {
81  uint64_t w = trsm_nwork('L','L','T',LU,B,'U');
82  iswap_rows(swaps,B);
83  return w;
84  }
85  else if (op == 'H')
86  {
87  uint64_t w = trsm_nwork('L','L','H',LU,B,'U');
88  iswap_rows(swaps,B);
89  return w;
90  }
91  else if (op == 'C')
92  {
93  swap_rows(swaps,B);
94  uint64_t w = trsm_nwork('L','L','C',LU,B,'U');
95  return w;
96  }
97  else throw eprintf("LUKernel::solveL('L'eft), didn't understand op = %c", op);
98  }
99  else if (side == 'R')
100  {
101  // Check size.
102  if (B.J != N) throw eprintf("LUKernel::solveL('R'ight), size mismatch B.J != N [%d != %d]", B.J, N);
103  if (op == 'N')
104  {
105  uint64_t w = trsm_nwork('R','L','N',LU,B,'U');
106  swap_columns(swaps,B);
107  return w;
108  }
109  else if (op == 'T')
110  {
111  iswap_columns(swaps,B);
112  uint64_t w = trsm_nwork('R','L','T',LU,B,'U');
113  return w;
114  }
115  else if (op == 'H')
116  {
117  iswap_columns(swaps,B);
118  uint64_t w = trsm_nwork('R','L','H',LU,B,'U');
119  return w;
120  }
121  else if (op == 'C')
122  {
123  uint64_t w = trsm_nwork('R','L','C',LU,B,'U');
124  swap_columns(swaps,B);
125  return w;
126  }
127  else throw eprintf("LUKernel::solveL('R'ight), didn't understand op = %c", op);
128  }
129  else throw eprintf("LUKernel::solveL(), didn't understand side = %c", side);
130  }
131 
133  // side = Solve by U from the 'L'eft or from the 'R'ight?
134  // op = Apply an operation to U? ('T'ranspose, 'H'ermitian, 'C'onjugate or 'N'othing)
135  uint64_t solveU(const MatrixRange<Number>& B, char side, char op) const
136  {
137  int N = this->size();
138  if (side == 'L')
139  {
140  // Check size.
141  if (B.I != N) throw eprintf("LUKernel::solveU('L'eft), size mismatch B.I != N [%d != %d]", B.I, N);
142  return trsm_nwork(side,'U',op,LU,B,'N');
143  }
144  else if (side == 'R')
145  {
146  // Check size.
147  if (B.J != N) throw eprintf("LUKernel::solveU('R'ight), size mismatch B.J != N [%d != %d]", B.J, N);
148  return trsm_nwork(side,'U',op,LU,B,'N');
149  }
150  else throw eprintf("LUKernel::solveU(), didn't understand side = %c", side);
151  }
152 
154  int size() const
155  { return LU.size().first; }
156 
157  private:
158 
159  // Points to underlying data.
161 
162  // Pivoting metadata - encodes row swaps.
163  std::vector<int> swaps;
164  };
165 
167 template<class Number> class ReflectNumber< LUKernel<Number> >
168  { public: typedef Number type; };
169 
170 } // namespace myra
171 
172 #endif
Reflects Number trait for a Container, containers of Numbers (Matrix's, Vector's, etc) should specialize...
Definition: Number.h:55
int size() const
Size inspector.
Definition: Kernel.h:154
Returns a std::runtime_error() whose message has been populated using printf()-style formatting...
Interface class for representing subranges of dense Matrix's.
int J
------------ Data members, all public ------------
Definition: MatrixRange.h:43
LUKernel(MatrixRange< Number > &A, InputStream &in)
Constructs from an InputStream, after seating reference to LU.
Definition: Kernel.h:48
int I
------------ Data members, all public ------------
Definition: MatrixRange.h:42
LUKernel(MatrixRange< Number > &A)
Seats reference to LU, to be factor()'d later.
Definition: Kernel.h:44
ReaderWriter<T>, encapsulates a read()/write() pair for type T.
Definition: syntax.dox:1
Factors A into L*U, presents solve methods.
Definition: Kernel.h:35
LUKernel()
Default constructor, initializes to 0 size.
Definition: Kernel.h:40
Routines for backsolving by a triangular Matrix or LowerMatrix.
Abstraction layer, serializable objects write themselves to these.
Definition: Streams.h:39
Various utility functions/classes related to scalar Number types.
Routines related to swap sequences, often used during pivoting.
Represents a mutable MatrixRange.
Definition: conjugate.h:26
Abstraction layer, deserializable objects read themselves from these.
Definition: Streams.h:47
uint64_t solveL(const MatrixRange< Number > &B, char side, char op) const
Solves op(L)*X=B or X*op(L)=B, overwrites B with X.
Definition: Kernel.h:66
void write(OutputStream &out) const
Writes to an OutputStream.
Definition: Kernel.h:52
uint64_t solveU(const MatrixRange< Number > &B, char side, char op) const
Solves op(U)*X=B or X*op(U)=B, overwrites B with X.
Definition: Kernel.h:135
Bases classes for binary input/output streams.
uint64_t factor()
Factors A = P'*L*U.
Definition: Kernel.h:56
General purpose A = P'*L*U decomposition for square Matrix's.