Scythe-1.0.3
|
00001 /* 00002 * Scythe Statistical Library Copyright (C) 2000-2002 Andrew D. Martin 00003 * and Kevin M. Quinn; 2002-present Andrew D. Martin, Kevin M. Quinn, 00004 * and Daniel Pemstein. All Rights Reserved. 00005 * 00006 * This program is free software; you can redistribute it and/or 00007 * modify it under the terms of the GNU General Public License as 00008 * published by the Free Software Foundation; either version 2 of the 00009 * License, or (at your option) any later version. See the text files 00010 * COPYING and LICENSE, distributed with this source code, for further 00011 * information. 00012 * -------------------------------------------------------------------- 00013 * scythestat/ide.h 00014 * 00015 * 00016 */ 00017 00042 /* TODO: This interface exposes the user to too much implementation. 00043 * We need a solve function and a solver object. By default, solve 00044 * would run lu_solve and the solver factory would return lu_solvers 00045 * (or perhaps a solver object encapsulating an lu_solver). Users 00046 * could choose cholesky when appropriate. Down the road, qr or svd 00047 * would become the default and we'd be able to handle non-square 00048 * matrices. Instead of doing an lu_decomp or a cholesky and keeping 00049 * track of the results to repeatedly solve for different b's with A 00050 * fixed in Ax=b, you'd just call the operator() on your solver object 00051 * over and over, passing the new b each time. No decomposition 00052 * specific solvers (except as toggles to the solver object and 00053 * solve function). We'd still provide cholesky and lu_decomp. We 00054 * could also think about a similar approach to inversion (one 00055 * inversion function with an option for method). 00056 * 00057 * If virtual dispatch in C++ weren't such a performance killer (no 00058 * compiler optimization across virtual calls!!!), there would be an 00059 * obvious implementation of this interface using simple polymorphism.
00060 * Unfortunately, we need compile-time typing to maintain performance 00061 * and makes developing a clean interface that doesn't force users to 00062 * be template wizards much harder. Initial experiments with the 00063 * Barton and Nackman trick were ugly. The engine approach might work 00064 * a bit better but has its problems too. This is not going to get 00065 * done for the 1.0 release, but it is something we should come back 00066 * to. 00067 * 00068 */ 00069 00070 #ifndef SCYTHE_IDE_H 00071 #define SCYTHE_IDE_H 00072 00073 #ifdef SCYTHE_COMPILE_DIRECT 00074 #include "matrix.h" 00075 #include "error.h" 00076 #include "defs.h" 00077 #ifdef SCYTHE_LAPACK 00078 #include "lapack.h" 00079 #include "stat.h" 00080 #endif 00081 #else 00082 #include "scythestat/matrix.h" 00083 #include "scythestat/error.h" 00084 #include "scythestat/defs.h" 00085 #ifdef SCYTHE_LAPACK 00086 #include "scythestat/lapack.h" 00087 #include "scythestat/stat.h" 00088 #endif 00089 #endif 00090 00091 #include <cmath> 00092 #include <algorithm> 00093 #include <complex> 00094 00095 namespace scythe { 00096 00097 namespace { 00098 typedef unsigned int uint; 00099 } 00100 00124 template <matrix_order RO, matrix_style RS, typename T, 00125 matrix_order PO, matrix_style PS> 00126 Matrix<T, RO, RS> 00127 cholesky (const Matrix<T, PO, PS>& A) 00128 { 00129 SCYTHE_CHECK_10(! A.isSquare(), scythe_dimension_error, 00130 "Matrix not square"); 00131 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 00132 "Matrix is NULL"); 00133 // Rounding errors can make this problematic. Leaving out for now 00134 //SCYTHE_CHECK_20(! 
A.isSymmetric(), scythe_type_error, 00135 // "Matrix not symmetric"); 00136 00137 Matrix<T,RO,Concrete> temp (A.rows(), A.cols(), false); 00138 T h; 00139 00140 if (PO == Row) { // row-major optimized 00141 for (uint i = 0; i < A.rows(); ++i) { 00142 for (uint j = i; j < A.cols(); ++j) { 00143 h = A(i,j); 00144 for (uint k = 0; k < i; ++k) 00145 h -= temp(i, k) * temp(j, k); 00146 if (i == j) { 00147 SCYTHE_CHECK_20(h <= (T) 0, scythe_type_error, 00148 "Matrix not positive definite"); 00149 00150 temp(i,i) = std::sqrt(h); 00151 } else { 00152 temp(j,i) = (((T) 1) / temp(i,i)) * h; 00153 temp(i,j) = (T) 0; 00154 } 00155 } 00156 } 00157 } else { // col-major optimized 00158 for (uint j = 0; j < A.cols(); ++j) { 00159 for (uint i = j; i < A.rows(); ++i) { 00160 h = A(i, j); 00161 for (uint k = 0; k < j; ++k) 00162 h -= temp(j, k) * temp(i, k); 00163 if (i == j) { 00164 SCYTHE_CHECK_20(h <= (T) 0, scythe_type_error, 00165 "Matrix not positive definite"); 00166 temp(j,j) = std::sqrt(h); 00167 } else { 00168 temp(i,j) = (((T) 1) / temp(j,j)) * h; 00169 temp(j,i) = (T) 0; 00170 } 00171 } 00172 } 00173 } 00174 00175 SCYTHE_VIEW_RETURN(T, RO, RS, temp) 00176 } 00177 00178 template <typename T, matrix_order O, matrix_style S> 00179 Matrix<T, O, Concrete> 00180 cholesky (const Matrix<T,O,S>& A) 00181 { 00182 return cholesky<O,Concrete>(A); 00183 } 00184 00185 namespace { 00186 /* This internal routine encapsulates the 00187 * algorithm used within chol_solve and lu_solve. 00188 */ 00189 template <typename T, 00190 matrix_order PO1, matrix_style PS1, 00191 matrix_order PO2, matrix_style PS2, 00192 matrix_order PO3, matrix_style PS3> 00193 inline void 00194 solve(const Matrix<T,PO1,PS1>& L, const Matrix<T,PO2,PS2>& U, 00195 Matrix<T,PO3,PS3> b, T* x, T* y) 00196 { 00197 T sum; 00198 00199 /* TODO: Consider optimizing for ordering. Experimentation 00200 * shows performance gains are probably minor (compared col-major 00201 * with and without lapack solve routines). 
00202 */ 00203 // solve M*y = b 00204 for (uint i = 0; i < b.size(); ++i) { 00205 sum = T (0); 00206 for (uint j = 0; j < i; ++j) { 00207 sum += L(i,j) * y[j]; 00208 } 00209 y[i] = (b[i] - sum) / L(i, i); 00210 } 00211 00212 // solve M'*x = y 00213 if (U.isNull()) { // A= LL^T 00214 for (int i = b.size() - 1; i >= 0; --i) { 00215 sum = T(0); 00216 for (uint j = i + 1; j < b.size(); ++j) { 00217 sum += L(j,i) * x[j]; 00218 } 00219 x[i] = (y[i] - sum) / L(i, i); 00220 } 00221 } else { // A = LU 00222 for (int i = b.size() - 1; i >= 0; --i) { 00223 sum = T(0); 00224 for (uint j = i + 1; j < b.size(); ++j) { 00225 sum += U(i,j) * x[j]; 00226 } 00227 x[i] = (y[i] - sum) / U(i, i); 00228 } 00229 } 00230 } 00231 } 00232 00259 template <matrix_order RO, matrix_style RS, typename T, 00260 matrix_order PO1, matrix_style PS1, 00261 matrix_order PO2, matrix_style PS2, 00262 matrix_order PO3, matrix_style PS3> 00263 Matrix<T,RO,RS> 00264 chol_solve (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& b, 00265 const Matrix<T,PO3,PS3>& M) 00266 { 00267 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 00268 "A is NULL") 00269 SCYTHE_CHECK_10(! b.isColVector(), scythe_dimension_error, 00270 "b must be a column vector"); 00271 SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error, 00272 "A and b do not conform"); 00273 SCYTHE_CHECK_10(A.rows() != M.rows(), scythe_conformation_error, 00274 "A and M do not conform"); 00275 SCYTHE_CHECK_10(! 
M.isSquare(), scythe_dimension_error, 00276 "M must be square"); 00277 00278 T *y = new T[A.rows()]; 00279 T *x = new T[A.rows()]; 00280 00281 solve(M, Matrix<>(), b, x, y); 00282 00283 Matrix<T,RO,RS> result(A.rows(), 1, x); 00284 00285 delete[]x; 00286 delete[]y; 00287 00288 return result; 00289 } 00290 00291 template <typename T, matrix_order PO1, matrix_style PS1, 00292 matrix_order PO2, matrix_style PS2, 00293 matrix_order PO3, matrix_style PS3> 00294 Matrix<T,PO1,Concrete> 00295 chol_solve (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& b, 00296 const Matrix<T,PO3,PS3>& M) 00297 { 00298 return chol_solve<PO1,Concrete>(A,b,M); 00299 } 00300 00325 template <matrix_order RO, matrix_style RS, typename T, 00326 matrix_order PO1, matrix_style PS1, 00327 matrix_order PO2, matrix_style PS2> 00328 Matrix<T,RO,RS> 00329 chol_solve (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& b) 00330 { 00331 /* NOTE: cholesky() call does check for square/posdef of A, 00332 * and the overloaded chol_solve call handles dimensions 00333 */ 00334 00335 return chol_solve<RO,RS>(A, b, cholesky<RO,Concrete>(A)); 00336 } 00337 00338 template <typename T, matrix_order PO1, matrix_style PS1, 00339 matrix_order PO2, matrix_style PS2> 00340 Matrix<T,PO1,Concrete> 00341 chol_solve (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& b) 00342 { 00343 return chol_solve<PO1,Concrete>(A, b); 00344 } 00345 00346 00369 template <matrix_order RO, matrix_style RS, typename T, 00370 matrix_order PO1, matrix_style PS1, 00371 matrix_order PO2, matrix_style PS2> 00372 Matrix<T,RO,RS> 00373 invpd (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& M) 00374 { 00375 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 00376 "A is NULL") 00377 SCYTHE_CHECK_10(! 
A.isSquare(), scythe_dimension_error, 00378 "A is not square") 00379 SCYTHE_CHECK_10(A.rows() != M.cols() || A.cols() != M.rows(), 00380 scythe_conformation_error, "A and M do not conform"); 00381 00382 // for chol_solve block 00383 T *y = new T[A.rows()]; 00384 T *x = new T[A.rows()]; 00385 Matrix<T, RO, Concrete> b(A.rows(), 1); // full of zeros 00386 Matrix<T, RO, Concrete> null; 00387 00388 // For final answer 00389 Matrix<T, RO, Concrete> Ainv(A.rows(), A.cols(), false); 00390 00391 for (uint k = 0; k < A.rows(); ++k) { 00392 b[k] = (T) 1; 00393 00394 solve(M, null, b, x, y); 00395 00396 b[k] = (T) 0; 00397 for (uint l = 0; l < A.rows(); ++l) 00398 Ainv(l,k) = x[l]; 00399 } 00400 00401 delete[] y; 00402 delete[] x; 00403 00404 SCYTHE_VIEW_RETURN(T, RO, RS, Ainv) 00405 } 00406 00407 template <typename T, matrix_order PO1, matrix_style PS1, 00408 matrix_order PO2, matrix_style PS2> 00409 Matrix<T,PO1,Concrete> 00410 invpd (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& M) 00411 { 00412 return invpd<PO1,Concrete>(A, M); 00413 } 00414 00435 template <matrix_order RO, matrix_style RS, typename T, 00436 matrix_order PO, matrix_style PS> 00437 Matrix<T, RO, RS> 00438 invpd (const Matrix<T, PO, PS>& A) 00439 { 00440 // Cholesky checks to see if A is square and symmetric 00441 00442 return invpd<RO,RS>(A, cholesky<RO,Concrete>(A)); 00443 } 00444 00445 template <typename T, matrix_order O, matrix_style S> 00446 Matrix<T, O, Concrete> 00447 invpd (const Matrix<T,O,S>& A) 00448 { 00449 return invpd<O,Concrete>(A); 00450 } 00451 00452 /* This code is based on Algorithm 3.4.1 of Golub and Van Loan 3rd 00453 * edition, 1996. Major difference is in how the output is 00454 * structured. Returns the sign of the row permutation (used by 00455 * det). Internal function, doesn't need doxygen. 
00456 */ 00457 namespace { // LU helper shared by lu_decomp, lu_solve, inv and det; returns the sign of the row permutation 00458 template <matrix_order PO1, matrix_style PS1, typename T, 00459 matrix_order PO2, matrix_order PO3, matrix_order PO4> 00460 inline T 00461 lu_decomp_alg(Matrix<T,PO1,PS1>& A, Matrix<T,PO2,Concrete>& L, 00462 Matrix<T,PO3,Concrete>& U, 00463 Matrix<unsigned int, PO4, Concrete>& perm_vec) 00464 { 00465 if (A.isRowVector()) { // 1 x n input (1 x 1 when square): nothing to eliminate 00466 L = Matrix<T,PO2,Concrete> (1, 1, true, 1); // all 1s 00467 U = A; 00468 perm_vec = Matrix<uint, PO4, Concrete>(1, 1); // all 0s 00469 return (T) 1; // no rows were swapped, so the permutation sign is +1 (returning 0 made det() of every 1x1 matrix 0) 00470 } 00471 00472 U = Matrix<T, PO3, Concrete>(A.rows(), A.cols(), true, (T) 0); // zero-filled: only the upper triangle is copied in below; L is rebuilt from A after elimination 00473 perm_vec = Matrix<uint, PO4, Concrete> (A.rows() - 1, 1, false); // every entry is written in the loop below 00474 00475 uint pivot; 00476 T temp; 00477 T sign = (T) 1; 00478 00479 for (uint k = 0; k < A.rows() - 1; ++k) { 00480 pivot = k; 00481 // find pivot (largest magnitude in column k at or below row k) 00482 for (uint i = k; i < A.rows(); ++i) { 00483 if (std::fabs(A(pivot,k)) < std::fabs(A(i,k))) 00484 pivot = i; 00485 } 00486 00487 SCYTHE_CHECK_20(A(pivot,k) == (T) 0, scythe_type_error, 00488 "Matrix is singular"); 00489 00490 // permute; each actual swap flips the sign 00491 if (k != pivot) { 00492 sign *= -1; 00493 for (uint i = 0; i < A.rows(); ++i) { 00494 temp = A(pivot,i); 00495 A(pivot,i) = A(k,i); 00496 A(k,i) = temp; 00497 } 00498 } 00499 perm_vec[k] = pivot; // row k was exchanged with row pivot 00500 00501 for (uint i = k + 1; i < A.rows(); ++i) { 00502 A(i,k) = A(i,k) / A(k,k); 00503 for (uint j = k + 1; j < A.rows(); ++j) 00504 A(i,j) = A(i,j) - A(i,k) * A(k,j); 00505 } 00506 } 00507 00508 L = A; // A now holds the multipliers below the diagonal and U on and above it 00509 00510 for (uint i = 0; i < A.rows(); ++i) { 00511 for (uint j = i; j < A.rows(); ++j) { 00512 U(i,j) = A(i,j); 00513 L(i,j) = (T) 0; 00514 L(i,i) = (T) 1; 00515 } 00516 } 00517 return sign; 00518 } 00519 } // anonymous namespace 00520 00521 /* Calculates the LU Decomposition of a square Matrix */ 00522 00523 /* Note that the L, U, and perm_vec must be concrete. A is passed by 00524 * value, because it is changed during the decomposition. If A is a 00525 * view, it will get mangled, but the decomposition will work fine.
00526 * Not sure what the copy/view access trade-off is, but passing a 00527 * view might speed things up if you don't care about messing up 00528 * your matrix. 00529 */ 00558 template <matrix_order PO1, matrix_style PS1, typename T, 00559 matrix_order PO2, matrix_order PO3, matrix_order PO4> 00560 void 00561 lu_decomp(Matrix<T,PO1,PS1> A, Matrix<T,PO2,Concrete>& L, 00562 Matrix<T,PO3,Concrete>& U, 00563 Matrix<unsigned int, PO4, Concrete>& perm_vec) 00564 { 00565 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 00566 "A is NULL") 00567 SCYTHE_CHECK_10(! A.isSquare(), scythe_dimension_error, 00568 "Matrix A not square"); 00569 00570 lu_decomp_alg(A, L, U, perm_vec); 00571 } 00572 00573 /* lu_solve overloaded: you need A, b + L, U, perm_vec from 00574 * lu_decomp. 00575 * 00576 */ 00605 template <matrix_order RO, matrix_style RS, typename T, 00606 matrix_order PO1, matrix_style PS1, 00607 matrix_order PO2, matrix_style PS2, 00608 matrix_order PO3, matrix_style PS3, 00609 matrix_order PO4, matrix_style PS4, 00610 matrix_order PO5, matrix_style PS5> 00611 Matrix<T, RO, RS> 00612 lu_solve (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& b, 00613 const Matrix<T,PO3,PS3>& L, const Matrix<T,PO4,PS4>& U, 00614 const Matrix<unsigned int, PO5, PS5> &perm_vec) 00615 { 00616 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 00617 "A is NULL") 00618 SCYTHE_CHECK_10(! b.isColVector(), scythe_dimension_error, 00619 "b is not a column vector"); 00620 SCYTHE_CHECK_10(! 
A.isSquare(), scythe_dimension_error, 00621 "A is not square"); 00622 SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error, 00623 "A and b have different row sizes"); 00624 SCYTHE_CHECK_10(A.rows() != L.rows() || A.rows() != U.rows() || 00625 A.cols() != L.cols() || A.cols() != U.cols(), 00626 scythe_conformation_error, 00627 "A, L, and U do not conform"); 00628 SCYTHE_CHECK_10(perm_vec.rows() + 1 != A.rows(), 00629 scythe_conformation_error, 00630 "perm_vec does not have exactly one less row than A"); 00631 00632 T *y = new T[A.rows()]; 00633 T *x = new T[A.rows()]; 00634 00635 Matrix<T,RO,Concrete> bb = row_interchange(b, perm_vec); 00636 solve(L, U, bb, x, y); 00637 00638 Matrix<T,RO,RS> result(A.rows(), 1, x); 00639 00640 delete[]x; 00641 delete[]y; 00642 00643 return result; 00644 } 00645 00646 template <typename T, matrix_order PO1, matrix_style PS1, 00647 matrix_order PO2, matrix_style PS2, 00648 matrix_order PO3, matrix_style PS3, 00649 matrix_order PO4, matrix_style PS4, 00650 matrix_order PO5, matrix_style PS5> 00651 Matrix<T, PO1, Concrete> 00652 lu_solve (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& b, 00653 const Matrix<T,PO3,PS3>& L, const Matrix<T,PO4,PS4>& U, 00654 const Matrix<unsigned int, PO5, PS5> &perm_vec) 00655 { 00656 return lu_solve<PO1,Concrete>(A, b, L, U, perm_vec); 00657 } 00658 00679 template <matrix_order RO, matrix_style RS, typename T, 00680 matrix_order PO1, matrix_style PS1, 00681 matrix_order PO2, matrix_style PS2> 00682 Matrix<T,RO,RS> 00683 lu_solve (Matrix<T,PO1,PS1> A, const Matrix<T,PO2,PS2>& b) 00684 { 00685 // step 1 compute the LU factorization 00686 Matrix<T, RO, Concrete> L, U; 00687 Matrix<uint, RO, Concrete> perm_vec; 00688 lu_decomp_alg(A, L, U, perm_vec); 00689 00690 return lu_solve<RO,RS>(A, b, L, U, perm_vec); 00691 } 00692 00693 template <typename T, matrix_order PO1, matrix_style PS1, 00694 matrix_order PO2, matrix_style PS2> 00695 Matrix<T,PO1,Concrete> 00696 lu_solve (Matrix<T,PO1,PS1> A, 
const Matrix<T,PO2,PS2>& b) 00697 { 00698 // Slight code rep here, but very few lines 00699 00700 // step 1 compute the LU factorization 00701 Matrix<T, PO1, Concrete> L, U; 00702 Matrix<uint, PO1, Concrete> perm_vec; 00703 lu_decomp_alg(A, L, U, perm_vec); 00704 00705 return lu_solve<PO1,Concrete>(A, b, L, U, perm_vec); 00706 } 00707 00731 template<matrix_order RO, matrix_style RS, typename T, 00732 matrix_order PO1, matrix_style PS1, 00733 matrix_order PO2, matrix_style PS2, 00734 matrix_order PO3, matrix_style PS3, 00735 matrix_order PO4, matrix_style PS4> 00736 Matrix<T,RO,RS> 00737 inv (const Matrix<T,PO1,PS1>& A, 00738 const Matrix<T,PO2,PS2>& L, const Matrix<T,PO3,PS3>& U, 00739 const Matrix<unsigned int,PO4,PS4>& perm_vec) 00740 { 00741 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 00742 "A is NULL") 00743 SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error, 00744 "A is not square"); 00745 SCYTHE_CHECK_10(A.rows() != L.rows() || A.rows() != U.rows() || 00746 A.cols() != L.cols() || A.cols() != U.cols(), 00747 scythe_conformation_error, 00748 "A, L, and U do not conform"); 00749 SCYTHE_CHECK_10(perm_vec.rows() + 1 != A.rows() 00750 && !(A.isScalar() && perm_vec.isScalar()), 00751 scythe_conformation_error, 00752 "perm_vec does not have exactly one less row than A"); 00753 00754 // For the final result 00755 Matrix<T,RO,Concrete> Ainv(A.rows(), A.rows(), false); 00756 00757 // for the solve block 00758 T *y = new T[A.rows()]; 00759 T *x = new T[A.rows()]; 00760 Matrix<T, RO, Concrete> b(A.rows(), 1); // full of zeros 00761 Matrix<T,RO,Concrete> bb; 00762 00763 for (uint k = 0; k < A.rows(); ++k) { 00764 b[k] = (T) 1; 00765 bb = row_interchange(b, perm_vec); 00766 00767 solve(L, U, bb, x, y); 00768 00769 b[k] = (T) 0; 00770 for (uint l = 0; l < A.rows(); ++l) 00771 Ainv(l,k) = x[l]; 00772 } 00773 00774 delete[] y; 00775 delete[] x; 00776 00777 SCYTHE_VIEW_RETURN(T, RO, RS, Ainv) 00778 } 00779 00780 template<typename T, 00781 matrix_order PO1, matrix_style 
PS1, 00782 matrix_order PO2, matrix_style PS2, 00783 matrix_order PO3, matrix_style PS3, 00784 matrix_order PO4, matrix_style PS4> 00785 Matrix<T,PO1,Concrete> 00786 inv (const Matrix<T,PO1,PS1>& A, 00787 const Matrix<T,PO2,PS2>& L, const Matrix<T,PO3,PS3>& U, 00788 const Matrix<unsigned int,PO4,PS4>& perm_vec) 00789 { 00790 return inv<PO1,Concrete>(A, L, U, perm_vec); 00791 } 00792 00811 template <matrix_order RO, matrix_style RS, typename T, 00812 matrix_order PO, matrix_style PS> 00813 Matrix<T, RO, RS> 00814 inv (const Matrix<T, PO, PS>& A) 00815 { 00816 // Make a copy of A for the decomposition (do it with an explicit 00817 // copy to a concrete case A is a view) 00818 Matrix<T,RO,Concrete> AA = A; 00819 00820 // step 1 compute the LU factorization 00821 Matrix<T, RO, Concrete> L, U; 00822 Matrix<uint, RO, Concrete> perm_vec; 00823 lu_decomp_alg(AA, L, U, perm_vec); 00824 00825 return inv<RO,RS>(A, L, U, perm_vec); 00826 } 00827 00828 template <typename T, matrix_order O, matrix_style S> 00829 Matrix<T, O, Concrete> 00830 inv (const Matrix<T, O, S>& A) 00831 { 00832 return inv<O,Concrete>(A); 00833 } 00834 00835 /* Interchanges the rows of A with those in vector p */ 00852 template <matrix_order RO, matrix_style RS, typename T, 00853 matrix_order PO1, matrix_style PS1, 00854 matrix_order PO2, matrix_style PS2> 00855 Matrix<T,RO,RS> 00856 row_interchange (Matrix<T,PO1,PS1> A, 00857 const Matrix<unsigned int,PO2,PS2>& p) 00858 { 00859 SCYTHE_CHECK_10(! p.isColVector(), scythe_dimension_error, 00860 "p not a column vector"); 00861 SCYTHE_CHECK_10(p.rows() + 1 != A.rows() && ! 
p.isScalar(), 00862 scythe_conformation_error, "p must have one less row than A"); 00863 00864 for (uint i = 0; i < A.rows() - 1; ++i) { 00865 Matrix<T,PO1,View> vec1 = A(i, _); 00866 Matrix<T,PO1,View> vec2 = A(p[i], _); 00867 std::swap_ranges(vec1.begin_f(), vec1.end_f(), vec2.begin_f()); 00868 } 00869 00870 return A; 00871 } 00872 00873 template <typename T, matrix_order PO1, matrix_style PS1, 00874 matrix_order PO2, matrix_style PS2> 00875 Matrix<T,PO1,Concrete> 00876 row_interchange (const Matrix<T,PO1,PS1>& A, 00877 const Matrix<unsigned int,PO2,PS2>& p) 00878 { 00879 return row_interchange<PO1,Concrete>(A, p); 00880 } 00881 00894 template <typename T, matrix_order PO, matrix_style PS> 00895 T 00896 det (const Matrix<T, PO, PS>& A) 00897 { 00898 SCYTHE_CHECK_10(! A.isSquare(), scythe_dimension_error, 00899 "Matrix is not square") 00900 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 00901 "Matrix is NULL") 00902 00903 // Make a copy of A for the decomposition (do it here instead of 00904 // at parameter pass in case A is a view) 00905 Matrix<T,PO,Concrete> AA = A; 00906 00907 // step 1 compute the LU factorization 00908 Matrix<T, PO, Concrete> L, U; 00909 Matrix<uint, PO, Concrete> perm_vec; 00910 T sign = lu_decomp_alg(AA, L, U, perm_vec); 00911 00912 // step 2 calculate the product of diag(U) and sign 00913 T det = (T) 1; 00914 for (uint i = 0; i < AA.rows(); ++i) 00915 det *= AA(i, i); 00916 00917 return sign * det; 00918 } 00919 00920 #ifdef SCYTHE_LAPACK 00921 00922 template<> 00923 inline Matrix<> 00924 cholesky (const Matrix<>& A) 00925 { 00926 SCYTHE_DEBUG_MSG("Using lapack/blas for cholesky"); 00927 SCYTHE_CHECK_10(! A.isSquare(), scythe_dimension_error, 00928 "Matrix not square"); 00929 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 00930 "Matrix is NULL"); 00931 00932 // We have to do an explicit copy within the func to match the 00933 // template declaration of the more general template. 
00934 Matrix<> AA = A; 00935 00936 // Get a pointer to the internal array and set up some vars 00937 double* Aarray = AA.getArray(); // internal array pointer 00938 int rows = (int) AA.rows(); // the dim of the matrix 00939 int err = 0; // The output error condition 00940 00941 // Cholesky decomposition step 00942 lapack::dpotrf_("L", &rows, Aarray, &rows, &err); 00943 SCYTHE_CHECK_10(err > 0, scythe_type_error, 00944 "Matrix is not positive definite") 00945 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg, 00946 "The " << err << "th value of the matrix had an illegal value") 00947 00948 // Zero out upper triangle 00949 for (uint j = 1; j < AA.cols(); ++j) 00950 for (uint i = 0; i < j; ++i) 00951 AA(i, j) = 0; 00952 00953 return AA; 00954 } 00955 00956 template<> 00957 inline Matrix<> 00958 chol_solve (const Matrix<>& A, const Matrix<>& b, const Matrix<>& M) 00959 { 00960 SCYTHE_DEBUG_MSG("Using lapack/blas for chol_solve"); 00961 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 00962 "A is NULL") 00963 SCYTHE_CHECK_10(! b.isColVector(), scythe_dimension_error, 00964 "b must be a column vector"); 00965 SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error, 00966 "A and b do not conform"); 00967 SCYTHE_CHECK_10(A.rows() != M.rows(), scythe_conformation_error, 00968 "A and M do not conform"); 00969 SCYTHE_CHECK_10(! M.isSquare(), scythe_dimension_error, 00970 "M must be square"); 00971 00972 // The algorithm modifies b in place. We make a copy. 
00973 Matrix<> bb = b; 00974 00975 // Get array pointers and set up some vars 00976 const double* Marray = M.getArray(); 00977 double* barray = bb.getArray(); 00978 int rows = (int) bb.rows(); 00979 int cols = (int) bb.cols(); // currently always one, but generalizable 00980 int err = 0; 00981 00982 // Solve the system 00983 lapack::dpotrs_("L", &rows, &cols, Marray, &rows, barray, &rows, &err); 00984 SCYTHE_CHECK_10(err > 0, scythe_type_error, 00985 "Matrix is not positive definite") 00986 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg, 00987 "The " << err << "th value of the matrix had an illegal value") 00988 00989 return bb; 00990 } 00991 00992 template<> 00993 inline Matrix<> 00994 chol_solve (const Matrix<>& A, const Matrix<>& b) 00995 { 00996 SCYTHE_DEBUG_MSG("Using lapack/blas for chol_solve"); 00997 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 00998 "A is NULL") 00999 SCYTHE_CHECK_10(! b.isColVector(), scythe_dimension_error, 01000 "b must be a column vector"); 01001 SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error, 01002 "A and b do not conform"); 01003 01004 // The algorithm modifies both A and b in place, so we make copies 01005 Matrix<> AA =A; 01006 Matrix<> bb = b; 01007 01008 // Get array pointers and set up some vars 01009 double* Aarray = AA.getArray(); 01010 double* barray = bb.getArray(); 01011 int rows = (int) bb.rows(); 01012 int cols = (int) bb.cols(); // currently always one, but generalizable 01013 int err = 0; 01014 01015 // Solve the system 01016 lapack::dposv_("L", &rows, &cols, Aarray, &rows, barray, &rows, &err); 01017 SCYTHE_CHECK_10(err > 0, scythe_type_error, 01018 "Matrix is not positive definite") 01019 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg, 01020 "The " << err << "th value of the matrix had an illegal value") 01021 01022 return bb; 01023 } 01024 01025 template <matrix_order PO2, matrix_order PO3, matrix_order PO4> 01026 inline double 01027 lu_decomp_alg(Matrix<>& A, Matrix<double,PO2,Concrete>& L, 01028 
Matrix<double,PO3,Concrete>& U, 01029 Matrix<unsigned int, PO4, Concrete>& perm_vec) 01030 { 01031 SCYTHE_DEBUG_MSG("Using lapack/blas for lu_decomp_alg"); 01032 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL") 01033 SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error, 01034 "A is not square"); 01035 01036 if (A.isRowVector()) { 01037 L = Matrix<double,PO2,Concrete> (1, 1, true, 1); // all 1s 01038 U = A; 01039 perm_vec = Matrix<uint, PO4, Concrete>(1, 1); // all 0s 01040 return 0.; 01041 } 01042 01043 L = U = Matrix<double, PO2, Concrete>(A.rows(), A.cols(), false); 01044 perm_vec = Matrix<uint, PO3, Concrete> (A.rows(), 1, false); 01045 01046 // Get a pointer to the internal array and set up some vars 01047 double* Aarray = A.getArray(); // internal array pointer 01048 int rows = (int) A.rows(); // the dim of the matrix 01049 int* ipiv = (int*) perm_vec.getArray(); // Holds the lu decomp pivot array 01050 int err = 0; // The output error condition 01051 01052 // Do the decomposition 01053 lapack::dgetrf_(&rows, &rows, Aarray, &rows, ipiv, &err); 01054 01055 SCYTHE_CHECK_10(err > 0, scythe_type_error, "Matrix is singular"); 01056 SCYTHE_CHECK_10(err < 0, scythe_lapack_internal_error, 01057 "The " << err << "th value of the matrix had an illegal value"); 01058 01059 // Now fill in the L and U matrices. 01060 L = A; 01061 for (uint i = 0; i < A.rows(); ++i) { 01062 for (uint j = i; j < A.rows(); ++j) { 01063 U(i,j) = A(i,j); 01064 L(i,j) = 0.; 01065 L(i,i) = 1.; 01066 } 01067 } 01068 01069 // Change to scythe's rows-1 perm_vec format and c++ indexing 01070 // XXX Cutting off the last pivot term may be buggy if it isn't 01071 // always just pointing at itself 01072 if (perm_vec(perm_vec.size() - 1) != perm_vec.size()) 01073 SCYTHE_THROW(scythe_unexpected_default_error, 01074 "This is an unexpected error. 
Please notify the developers.") 01075 perm_vec = perm_vec(0, 0, perm_vec.rows() - 2, 0) - 1; 01076 01077 // Finally, figure out the sign of perm_vec 01078 if (sum(perm_vec > 0) % 2 == 0) 01079 return 1; 01080 01081 return -1; 01082 } 01083 01103 struct QRdecomp { 01104 Matrix<> QR; 01105 Matrix<> tau; 01106 Matrix<> pivot; 01107 }; 01108 01142 inline QRdecomp 01143 qr_decomp (const Matrix<>& A) 01144 { 01145 SCYTHE_DEBUG_MSG("Using lapack/blas for qr_decomp"); 01146 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL"); 01147 01148 // Set up working variables 01149 Matrix<> QR = A; 01150 double* QRarray = QR.getArray(); // input/output array pointer 01151 int rows = (int) QR.rows(); 01152 int cols = (int) QR.cols(); 01153 Matrix<unsigned int> pivot(cols, 1); // pivot vector 01154 int* parray = (int*) pivot.getArray(); // pivot vector array pointer 01155 Matrix<> tau = Matrix<>(rows < cols ? rows : cols, 1); 01156 double* tarray = tau.getArray(); // tau output array pointer 01157 double tmp, *work; // workspace vars 01158 int lwork, info; // workspace size var and error info var 01159 01160 // Get workspace size 01161 lwork = -1; 01162 lapack::dgeqp3_(&rows, &cols, QRarray, &rows, parray, tarray, &tmp, 01163 &lwork, &info); 01164 01165 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, 01166 "Internal error in LAPACK routine dgeqp3"); 01167 01168 lwork = (int) tmp; 01169 work = new double[lwork]; 01170 01171 // run the routine for real 01172 lapack::dgeqp3_(&rows, &cols, QRarray, &rows, parray, tarray, work, 01173 &lwork, &info); 01174 01175 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, 01176 "Internal error in LAPACK routine dgeqp3"); 01177 01178 delete[] work; 01179 01180 pivot -= 1; 01181 01182 QRdecomp result; 01183 result.QR = QR; 01184 result.tau = tau; 01185 result.pivot = pivot; 01186 01187 return result; 01188 } 01189 01226 inline Matrix<> 01227 qr_solve(const Matrix<>& A, const Matrix<>& b, const QRdecomp& QR) 01228 { 01229 
SCYTHE_DEBUG_MSG("Using lapack/blas for qr_solve"); 01230 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL") 01231 SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error, 01232 "A and b do not conform"); 01233 SCYTHE_CHECK_10(A.rows() != QR.QR.rows() || A.cols() != QR.QR.cols(), 01234 scythe_conformation_error, "A and QR do not conform"); 01235 int taudim = (int) (A.rows() < A.cols() ? A.rows() : A.cols()); 01236 SCYTHE_CHECK_10(QR.tau.size() != taudim, scythe_conformation_error, 01237 "A and tau do not conform"); 01238 SCYTHE_CHECK_10(QR.pivot.size() != A.cols(), scythe_conformation_error, 01239 "pivot vector is not the right length"); 01240 01241 int rows = (int) QR.QR.rows(); 01242 int cols = (int) QR.QR.cols(); 01243 int nrhs = (int) b.cols(); 01244 int lwork, info; 01245 double *work, tmp; 01246 double* QRarray = QR.QR.getArray(); 01247 double* tarray = QR.tau.getArray(); 01248 Matrix<> bb = b; 01249 double* barray = bb.getArray(); 01250 01251 // Get workspace size 01252 lwork = -1; 01253 lapack::dormqr_("L", "T", &rows, &nrhs, &taudim, QRarray, &rows, 01254 tarray, barray, &rows, &tmp, &lwork, &info); 01255 01256 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, 01257 "Internal error in LAPACK routine dormqr"); 01258 01259 // And now for real 01260 lwork = (int) tmp; 01261 work = new double[lwork]; 01262 lapack::dormqr_("L", "T", &rows, &nrhs, &taudim, QRarray, &rows, 01263 tarray, barray, &rows, work, &lwork, &info); 01264 01265 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, 01266 "Internal error in LAPACK routine dormqr"); 01267 01268 lapack::dtrtrs_("U", "N", "N", &taudim, &nrhs, QRarray, &rows, barray, 01269 &rows, &info); 01270 01271 SCYTHE_CHECK_10(info > 0, scythe_type_error, "Matrix is singular"); 01272 SCYTHE_CHECK_10(info < 0, scythe_lapack_internal_error, 01273 "Internal error in LAPACK routine dtrtrs"); 01274 01275 delete[] work; 01276 01277 Matrix<> result(A.cols(), b.cols(), false); 01278 for (uint i = 0; i < 
QR.pivot.size(); ++i) 01279 result(i, _) = bb((uint) QR.pivot(i), _); 01280 return result; 01281 } 01282 01314 inline Matrix<> 01315 qr_solve (const Matrix<>& A, const Matrix<>& b) 01316 { 01317 SCYTHE_DEBUG_MSG("Using lapack/blas for qr_solve"); 01318 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL") 01319 SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error, 01320 "A and b do not conform"); 01321 01322 /* Do decomposition */ 01323 01324 // Set up working variables 01325 Matrix<> QR = A; 01326 double* QRarray = QR.getArray(); // input/output array pointer 01327 int rows = (int) QR.rows(); 01328 int cols = (int) QR.cols(); 01329 Matrix<unsigned int> pivot(cols, 1); // pivot vector 01330 int* parray = (int*) pivot.getArray(); // pivot vector array pointer 01331 Matrix<> tau = Matrix<>(rows < cols ? rows : cols, 1); 01332 double* tarray = tau.getArray(); // tau output array pointer 01333 double tmp, *work; // workspace vars 01334 int lwork, info; // workspace size var and error info var 01335 01336 // Get workspace size 01337 lwork = -1; 01338 lapack::dgeqp3_(&rows, &cols, QRarray, &rows, parray, tarray, &tmp, 01339 &lwork, &info); 01340 01341 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, 01342 "Internal error in LAPACK routine dgeqp3"); 01343 01344 lwork = (int) tmp; 01345 work = new double[lwork]; 01346 01347 // run the routine for real 01348 lapack::dgeqp3_(&rows, &cols, QRarray, &rows, parray, tarray, work, 01349 &lwork, &info); 01350 01351 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, 01352 "Internal error in LAPACK routine dgeqp3"); 01353 01354 delete[] work; 01355 01356 pivot -= 1; 01357 01358 /* Now solve the system */ 01359 01360 // working vars 01361 int nrhs = (int) b.cols(); 01362 Matrix<> bb = b; 01363 double* barray = bb.getArray(); 01364 int taudim = (int) tau.size(); 01365 01366 // Get workspace size 01367 lwork = -1; 01368 lapack::dormqr_("L", "T", &rows, &nrhs, &taudim, QRarray, &rows, 01369 tarray, barray, 
&rows, &tmp, &lwork, &info); 01370 01371 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, 01372 "Internal error in LAPACK routine dormqr"); 01373 01374 // And now for real 01375 lwork = (int) tmp; 01376 work = new double[lwork]; 01377 lapack::dormqr_("L", "T", &rows, &nrhs, &taudim, QRarray, &rows, 01378 tarray, barray, &rows, work, &lwork, &info); 01379 01380 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, 01381 "Internal error in LAPACK routine dormqr"); 01382 01383 lapack::dtrtrs_("U", "N", "N", &taudim, &nrhs, QRarray, &rows, barray, 01384 &rows, &info); 01385 01386 SCYTHE_CHECK_10(info > 0, scythe_type_error, "Matrix is singular"); 01387 SCYTHE_CHECK_10(info < 0, scythe_lapack_internal_error, 01388 "Internal error in LAPACK routine dtrtrs"); 01389 01390 delete[] work; 01391 01392 Matrix<> result(A.cols(), b.cols(), false); 01393 for (uint i = 0; i < pivot.size(); ++i) 01394 result(i, _) = bb(pivot(i), _); 01395 01396 return result; 01397 } 01398 01399 template<> 01400 inline Matrix<> 01401 invpd (const Matrix<>& A) 01402 { 01403 SCYTHE_DEBUG_MSG("Using lapack/blas for invpd"); 01404 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 01405 "A is NULL") 01406 SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error, 01407 "A is not square"); 01408 01409 // We have to do an explicit copy within the func to match the 01410 // template declaration of the more general template. 
01411 Matrix<> AA = A; 01412 01413 // Get a pointer to the internal array and set up some vars 01414 double* Aarray = AA.getArray(); // internal array pointer 01415 int rows = (int) AA.rows(); // the dim of the matrix 01416 int err = 0; // The output error condition 01417 01418 // Cholesky decomposition step 01419 lapack::dpotrf_("L", &rows, Aarray, &rows, &err); 01420 SCYTHE_CHECK_10(err > 0, scythe_type_error, 01421 "Matrix is not positive definite") 01422 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg, 01423 "The " << err << "th value of the matrix had an illegal value") 01424 01425 // Inversion step 01426 lapack::dpotri_("L", &rows, Aarray, &rows, &err); 01427 SCYTHE_CHECK_10(err > 0, scythe_type_error, 01428 "The (" << err << ", " << err << ") element of the matrix is zero" 01429 << " and the inverse could not be computed") 01430 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg, 01431 "The " << err << "th value of the matrix had an illegal value") 01432 lapack::make_symmetric(Aarray, rows); 01433 01434 return AA; 01435 } 01436 01437 template<> 01438 inline Matrix<> 01439 invpd (const Matrix<>& A, const Matrix<>& M) 01440 { 01441 SCYTHE_DEBUG_MSG("Using lapack/blas for invpd"); 01442 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 01443 "A is NULL") 01444 SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error, 01445 "A is not square"); 01446 SCYTHE_CHECK_10(A.rows() != M.cols() || A.cols() != M.rows(), 01447 scythe_conformation_error, "A and M do not conform"); 01448 01449 // We have to do an explicit copy within the func to match the 01450 // template declaration of the more general template. 
01451 Matrix<> MM = M; 01452 01453 // Get pointer and set up some vars 01454 double* Marray = MM.getArray(); 01455 int rows = (int) MM.rows(); 01456 int err = 0; 01457 01458 // Inversion step 01459 lapack::dpotri_("L", &rows, Marray, &rows, &err); 01460 SCYTHE_CHECK_10(err > 0, scythe_type_error, 01461 "The (" << err << ", " << err << ") element of the matrix is zero" 01462 << " and the inverse could not be computed") 01463 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg, 01464 "The " << err << "th value of the matrix had an illegal value") 01465 lapack::make_symmetric(Marray, rows); 01466 01467 return MM; 01468 } 01469 01470 template <> 01471 inline Matrix<> 01472 inv(const Matrix<>& A) 01473 { 01474 SCYTHE_DEBUG_MSG("Using lapack/blas for inv"); 01475 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 01476 "A is NULL") 01477 SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error, 01478 "A is not square"); 01479 01480 // We have to do an explicit copy within the func to match the 01481 // template declaration of the more general template. 
01482 Matrix<> AA = A; 01483 01484 // Get a pointer to the internal array and set up some vars 01485 double* Aarray = AA.getArray(); // internal array pointer 01486 int rows = (int) AA.rows(); // the dim of the matrix 01487 int* ipiv = new int[rows]; // Holds the lu decomp pivot array 01488 int err = 0; // The output error condition 01489 01490 // LU decomposition step 01491 lapack::dgetrf_(&rows, &rows, Aarray, &rows, ipiv, &err); 01492 01493 SCYTHE_CHECK_10(err > 0, scythe_type_error, "Matrix is singular"); 01494 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg, 01495 "The " << err << "th value of the matrix had an illegal value"); 01496 01497 // Inversion step; first do a workspace query, then the actual 01498 // inversion 01499 double work_query = 0; 01500 int work_size = -1; 01501 lapack::dgetri_(&rows, Aarray, &rows, ipiv, &work_query, 01502 &work_size, &err); 01503 double* workspace = new double[(work_size = (int) work_query)]; 01504 lapack::dgetri_(&rows, Aarray, &rows, ipiv, workspace, &work_size, 01505 &err); 01506 delete[] ipiv; 01507 delete[] workspace; 01508 01509 SCYTHE_CHECK_10(err > 0, scythe_type_error, "Matrix is singular"); 01510 SCYTHE_CHECK_10(err < 0, scythe_invalid_arg, 01511 "Internal error in LAPACK routine dgetri"); 01512 01513 return AA; 01514 } 01515 01531 struct SVD { 01532 Matrix<> d; // singular values 01533 Matrix<> U; // left singular vectors 01534 Matrix<> Vt; // transpose of right singular vectors 01535 }; 01536 01565 inline SVD 01566 svd (const Matrix<>& A, int nu = -1, int nv = -1) 01567 { 01568 SCYTHE_DEBUG_MSG("Using lapack/blas for eigen"); 01569 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 01570 "Matrix is NULL"); 01571 01572 char* jobz; 01573 int m = (int) A.rows(); 01574 int n = (int) A.cols(); 01575 int mn = (int) std::min(A.rows(), A.cols()); 01576 Matrix<> U; 01577 Matrix<> V; 01578 if (nu < 0) nu = mn; 01579 if (nv < 0) nv = mn; 01580 if (nu <= mn && nv<= mn) { 01581 jobz = "S"; 01582 U = Matrix<>(m, mn, false); 01583 V = 
Matrix<>(mn, n, false); 01584 } else if (nu == 0 && nv == 0) { 01585 jobz = "N"; 01586 } else { 01587 jobz = "A"; 01588 U = Matrix<>(m, m, false); 01589 V = Matrix<>(n, n, false); 01590 } 01591 double* Uarray = U.getArray(); 01592 double* Varray = V.getArray(); 01593 01594 int ldu = (int) U.rows(); 01595 int ldvt = (int) V.rows(); 01596 Matrix<> X = A; 01597 double* Xarray = X.getArray(); 01598 Matrix<> d(mn, 1, false); 01599 double* darray = d.getArray(); 01600 01601 double tmp, *work; 01602 int lwork, info; 01603 int *iwork = new int[8 * mn]; 01604 01605 // get optimal workspace 01606 lwork = -1; 01607 lapack::dgesdd_(jobz, &m, &n, Xarray, &m, darray, Uarray, &ldu, 01608 Varray, &ldvt, &tmp, &lwork, iwork, &info); 01609 SCYTHE_CHECK_10(info < 0, scythe_lapack_internal_error, 01610 "Internal error in LAPACK routine dgessd"); 01611 SCYTHE_CHECK_10(info > 0, scythe_convergence_error, "Did not converge"); 01612 01613 lwork = (int) tmp; 01614 work = new double[lwork]; 01615 01616 // Now for real 01617 lapack::dgesdd_(jobz, &m, &n, Xarray, &m, darray, Uarray, &ldu, 01618 Varray, &ldvt, work, &lwork, iwork, &info); 01619 SCYTHE_CHECK_10(info < 0, scythe_lapack_internal_error, 01620 "Internal error in LAPACK routine dgessd"); 01621 SCYTHE_CHECK_10(info > 0, scythe_convergence_error, "Did not converge"); 01622 delete[] work; 01623 01624 if (nu < mn && nu > 0) 01625 U = U(0, 0, U.rows() - 1, (unsigned int) std::min(m, nu) - 1); 01626 if (nv < mn && nv > 0) 01627 V = V(0, 0, (unsigned int) std::min(n, nv) - 1, V.cols() - 1); 01628 SVD result; 01629 result.d = d; 01630 result.U = U; 01631 result.Vt = V; 01632 01633 return result; 01634 } 01635 01647 struct Eigen { 01648 Matrix<> values; 01649 Matrix<> vectors; 01650 }; 01651 01680 inline Eigen 01681 eigen (const Matrix<>& A, bool vectors=true) 01682 { 01683 SCYTHE_DEBUG_MSG("Using lapack/blas for eigen"); 01684 SCYTHE_CHECK_10(! 
A.isSquare(), scythe_dimension_error, 01685 "Matrix not square"); 01686 SCYTHE_CHECK_10(A.isNull(), scythe_null_error, 01687 "Matrix is NULL"); 01688 // Should be symmetric but rounding errors make checking for this 01689 // difficult. 01690 01691 // Make a copy of A 01692 Matrix<> AA = A; 01693 01694 // Get a point to the internal array and set up some vars 01695 double* Aarray = AA.getArray(); // internal array points 01696 int order = (int) AA.rows(); // input matrix is order x order 01697 double dignored = 0; // we don't use this option 01698 int iignored = 0; // or this one 01699 double abstol = 0.0; // tolerance (default) 01700 int m; // output value 01701 Matrix<> result; // result matrix 01702 char getvecs[1]; // are we getting eigenvectors? 01703 if (vectors) { 01704 getvecs[0] = 'V'; 01705 result = Matrix<>(order, order + 1, false); 01706 } else { 01707 result = Matrix<>(order, 1, false); 01708 getvecs[0] = 'N'; 01709 } 01710 double* eigenvalues = result.getArray(); // pointer to result array 01711 int* isuppz = new int[2 * order]; // indices of nonzero eigvecs 01712 double tmp; // inital temporary value for getting work-space info 01713 int lwork, liwork, *iwork, itmp; // stuff for workspace 01714 double *work; // and more stuff for workspace 01715 int info = 0; // error code holder 01716 01717 // get optimal size for work arrays 01718 lwork = -1; 01719 liwork = -1; 01720 lapack::dsyevr_(getvecs, "A", "L", &order, Aarray, &order, &dignored, 01721 &dignored, &iignored, &iignored, &abstol, &m, eigenvalues, 01722 eigenvalues + order, &order, isuppz, &tmp, &lwork, &itmp, 01723 &liwork, &info); 01724 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, 01725 "Internal error in LAPACK routine dsyevr"); 01726 lwork = (int) tmp; 01727 liwork = itmp; 01728 work = new double[lwork]; 01729 iwork = new int[liwork]; 01730 01731 // do the actual operation 01732 lapack::dsyevr_(getvecs, "A", "L", &order, Aarray, &order, &dignored, 01733 &dignored, &iignored, 
&iignored, &abstol, &m, eigenvalues, 01734 eigenvalues + order, &order, isuppz, work, &lwork, iwork, 01735 &liwork, &info); 01736 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, 01737 "Internal error in LAPACK routine dsyevr"); 01738 01739 delete[] isuppz; 01740 delete[] work; 01741 delete[] iwork; 01742 01743 Eigen resobj; 01744 if (vectors) { 01745 resobj.values = result(_, 0); 01746 resobj.vectors = result(0, 1, result.rows() -1, result.cols() - 1); 01747 } else { 01748 resobj.values = result; 01749 } 01750 01751 return resobj; 01752 } 01753 01754 01755 struct GeneralEigen { 01756 Matrix<std::complex<double> > values; 01757 Matrix<> vectors; 01758 }; 01759 01760 inline GeneralEigen 01761 geneigen (const Matrix<>& A, bool vectors=true) 01762 { 01763 SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error, 01764 "Matrix not square"); 01765 SCYTHE_CHECK_10 (A.isNull(), scythe_null_error, "Matrix is NULL"); 01766 01767 Matrix<> AA = A; // Copy A 01768 01769 // Get a point to the internal array and set up some vars 01770 double* Aarray = AA.getArray(); // internal array points 01771 int order = (int) AA.rows(); // input matrix is order x order 01772 01773 GeneralEigen result; 01774 01775 int info, lwork; 01776 double *left, *right, *valreal, *valimag, *work, tmp; 01777 valreal = new double[order]; 01778 valimag = new double[order]; 01779 left = right = (double *) 0; 01780 char leftvecs[1], rightvecs[1]; 01781 leftvecs[0] = rightvecs[0] = 'N'; 01782 if (vectors) { 01783 rightvecs[0] = 'V'; 01784 result.vectors = Matrix<>(order, order, false); 01785 right = result.vectors.getArray(); 01786 } 01787 01788 // Get working are size 01789 lwork = -1; 01790 lapack::dgeev_ (leftvecs, rightvecs, &order, Aarray, &order, 01791 valreal, valimag, left, &order, right, &order, 01792 &tmp, &lwork, &info); 01793 01794 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, 01795 "Internal error in LAPACK routine dgeev"); 01796 lwork = (int) tmp; 01797 work = new 
double[lwork]; 01798 01799 // Run for real 01800 lapack::dgeev_ (leftvecs, rightvecs, &order, Aarray, &order, 01801 valreal, valimag, left, &order, right, &order, 01802 work, &lwork, &info); 01803 SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, 01804 "Internal error in LAPACK routine dgeev"); 01805 01806 // Pack value into result 01807 result.values = Matrix<std::complex<double> > (order, 1, false); 01808 for (unsigned int i = 0; i < result.values.size(); ++i) 01809 result.values(i) = std::complex<double> (valreal[i], valimag[i]); 01810 01811 // Clean up 01812 delete[] valreal; 01813 delete[] valimag; 01814 delete[] work; 01815 01816 01817 return result; 01818 } 01819 01820 01821 01822 #endif 01823 01824 } // end namespace scythe 01825 01826 #endif /* SCYTHE_IDE_H */