ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/OpenMD/trunk/src/math/jama_lu.hpp
Revision: 1336
Committed: Tue Apr 7 20:16:26 2009 UTC (16 years, 1 month ago) by gezelter
File size: 7165 byte(s)
Log Message:
adding jama and tnt libraries for various linear algebra routines

File Contents

# User Rev Content
1 gezelter 1336 #ifndef JAMA_LU_H
2     #define JAMA_LU_H
3    
#include "tnt_array1d.hpp"
#include "tnt_array1d_utils.hpp"
#include "tnt_array2d.hpp"
#include "tnt_array2d_utils.hpp"
#include "tnt_math_utils.hpp"

#include <algorithm>
//for min(), max() below
#include <cmath>
// for std::abs() on floating-point element types
#include <utility>
// for std::swap()
12    
13     using namespace TNT;
14     using namespace std;
15    
16     namespace JAMA
17     {
18    
19     /** LU Decomposition.
20     <P>
21     For an m-by-n matrix A with m >= n, the LU decomposition is an m-by-n
22     unit lower triangular matrix L, an n-by-n upper triangular matrix U,
23     and a permutation vector piv of length m so that A(piv,:) = L*U.
24     If m < n, then L is m-by-m and U is m-by-n.
25     <P>
26     The LU decompostion with pivoting always exists, even if the matrix is
27     singular, so the constructor will never fail. The primary use of the
28     LU decomposition is in the solution of square systems of simultaneous
29     linear equations. This will fail if isNonsingular() returns false.
30     */
31     template <class Real>
32     class LU
33     {
34    
35    
36    
37     /* Array for internal storage of decomposition. */
38     Array2D<Real> LU_;
39     int m, n, pivsign;
40     Array1D<int> piv;
41    
42    
43     Array2D<Real> permute_copy(const Array2D<Real> &A,
44     const Array1D<int> &piv, int j0, int j1)
45     {
46     int piv_length = piv.dim();
47    
48     Array2D<Real> X(piv_length, j1-j0+1);
49    
50    
51     for (int i = 0; i < piv_length; i++)
52     for (int j = j0; j <= j1; j++)
53     X[i][j-j0] = A[piv[i]][j];
54    
55     return X;
56     }
57    
58     Array1D<Real> permute_copy(const Array1D<Real> &A,
59     const Array1D<int> &piv)
60     {
61     int piv_length = piv.dim();
62     if (piv_length != A.dim())
63     return Array1D<Real>();
64    
65     Array1D<Real> x(piv_length);
66    
67    
68     for (int i = 0; i < piv_length; i++)
69     x[i] = A[piv[i]];
70    
71     return x;
72     }
73    
74    
75     public :
76    
77     /** LU Decomposition
78     @param A Rectangular matrix
79     @return LU Decomposition object to access L, U and piv.
80     */
81    
82     LU (const Array2D<Real> &A) : LU_(A.copy()), m(A.dim1()), n(A.dim2()),
83     piv(A.dim1())
84    
85     {
86    
87     // Use a "left-looking", dot-product, Crout/Doolittle algorithm.
88    
89    
90     for (int i = 0; i < m; i++) {
91     piv[i] = i;
92     }
93     pivsign = 1;
94     Real *LUrowi = 0;;
95     Array1D<Real> LUcolj(m);
96    
97     // Outer loop.
98    
99     for (int j = 0; j < n; j++) {
100    
101     // Make a copy of the j-th column to localize references.
102    
103     for (int i = 0; i < m; i++) {
104     LUcolj[i] = LU_[i][j];
105     }
106    
107     // Apply previous transformations.
108    
109     for (int i = 0; i < m; i++) {
110     LUrowi = LU_[i];
111    
112     // Most of the time is spent in the following dot product.
113    
114     int kmax = min(i,j);
115     double s = 0.0;
116     for (int k = 0; k < kmax; k++) {
117     s += LUrowi[k]*LUcolj[k];
118     }
119    
120     LUrowi[j] = LUcolj[i] -= s;
121     }
122    
123     // Find pivot and exchange if necessary.
124    
125     int p = j;
126     for (int i = j+1; i < m; i++) {
127     if (abs(LUcolj[i]) > abs(LUcolj[p])) {
128     p = i;
129     }
130     }
131     if (p != j) {
132     int k=0;
133     for (k = 0; k < n; k++) {
134     double t = LU_[p][k];
135     LU_[p][k] = LU_[j][k];
136     LU_[j][k] = t;
137     }
138     k = piv[p];
139     piv[p] = piv[j];
140     piv[j] = k;
141     pivsign = -pivsign;
142     }
143    
144     // Compute multipliers.
145    
146     if ((j < m) && (LU_[j][j] != 0.0)) {
147     for (int i = j+1; i < m; i++) {
148     LU_[i][j] /= LU_[j][j];
149     }
150     }
151     }
152     }
153    
154    
155     /** Is the matrix nonsingular?
156     @return 1 (true) if upper triangular factor U (and hence A)
157     is nonsingular, 0 otherwise.
158     */
159    
160     int isNonsingular () {
161     for (int j = 0; j < n; j++) {
162     if (LU_[j][j] == 0)
163     return 0;
164     }
165     return 1;
166     }
167    
168     /** Return lower triangular factor
169     @return L
170     */
171    
172     Array2D<Real> getL () {
173     Array2D<Real> L_(m,n);
174     for (int i = 0; i < m; i++) {
175     for (int j = 0; j < n; j++) {
176     if (i > j) {
177     L_[i][j] = LU_[i][j];
178     } else if (i == j) {
179     L_[i][j] = 1.0;
180     } else {
181     L_[i][j] = 0.0;
182     }
183     }
184     }
185     return L_;
186     }
187    
188     /** Return upper triangular factor
189     @return U portion of LU factorization.
190     */
191    
192     Array2D<Real> getU () {
193     Array2D<Real> U_(n,n);
194     for (int i = 0; i < n; i++) {
195     for (int j = 0; j < n; j++) {
196     if (i <= j) {
197     U_[i][j] = LU_[i][j];
198     } else {
199     U_[i][j] = 0.0;
200     }
201     }
202     }
203     return U_;
204     }
205    
206     /** Return pivot permutation vector
207     @return piv
208     */
209    
210     Array1D<int> getPivot () {
211     return piv;
212     }
213    
214    
215     /** Compute determinant using LU factors.
216     @return determinant of A, or 0 if A is not square.
217     */
218    
219     Real det () {
220     if (m != n) {
221     return Real(0);
222     }
223     Real d = Real(pivsign);
224     for (int j = 0; j < n; j++) {
225     d *= LU_[j][j];
226     }
227     return d;
228     }
229    
230     /** Solve A*X = B
231     @param B A Matrix with as many rows as A and any number of columns.
232     @return X so that L*U*X = B(piv,:), if B is nonconformant, returns
233     0x0 (null) array.
234     */
235    
236     Array2D<Real> solve (const Array2D<Real> &B)
237     {
238    
239     /* Dimensions: A is mxn, X is nxk, B is mxk */
240    
241     if (B.dim1() != m) {
242     return Array2D<Real>(0,0);
243     }
244     if (!isNonsingular()) {
245     return Array2D<Real>(0,0);
246     }
247    
248     // Copy right hand side with pivoting
249     int nx = B.dim2();
250    
251    
252     Array2D<Real> X = permute_copy(B, piv, 0, nx-1);
253    
254     // Solve L*Y = B(piv,:)
255     for (int k = 0; k < n; k++) {
256     for (int i = k+1; i < n; i++) {
257     for (int j = 0; j < nx; j++) {
258     X[i][j] -= X[k][j]*LU_[i][k];
259     }
260     }
261     }
262     // Solve U*X = Y;
263     for (int k = n-1; k >= 0; k--) {
264     for (int j = 0; j < nx; j++) {
265     X[k][j] /= LU_[k][k];
266     }
267     for (int i = 0; i < k; i++) {
268     for (int j = 0; j < nx; j++) {
269     X[i][j] -= X[k][j]*LU_[i][k];
270     }
271     }
272     }
273     return X;
274     }
275    
276    
277     /** Solve A*x = b, where x and b are vectors of length equal
278     to the number of rows in A.
279    
280     @param b a vector (Array1D> of length equal to the first dimension
281     of A.
282     @return x a vector (Array1D> so that L*U*x = b(piv), if B is nonconformant,
283     returns 0x0 (null) array.
284     */
285    
286     Array1D<Real> solve (const Array1D<Real> &b)
287     {
288    
289     /* Dimensions: A is mxn, X is nxk, B is mxk */
290    
291     if (b.dim1() != m) {
292     return Array1D<Real>();
293     }
294     if (!isNonsingular()) {
295     return Array1D<Real>();
296     }
297    
298    
299     Array1D<Real> x = permute_copy(b, piv);
300    
301     // Solve L*Y = B(piv)
302     for (int k = 0; k < n; k++) {
303     for (int i = k+1; i < n; i++) {
304     x[i] -= x[k]*LU_[i][k];
305     }
306     }
307    
308     // Solve U*X = Y;
309     for (int k = n-1; k >= 0; k--) {
310     x[k] /= LU_[k][k];
311     for (int i = 0; i < k; i++)
312     x[i] -= x[k]*LU_[i][k];
313     }
314    
315    
316     return x;
317     }
318    
319     }; /* class LU */
320    
321     } /* namespace JAMA */
322    
323     #endif
324     /* JAMA_LU_H */