Scythe-1.0.3
|
00001 /* 00002 * Scythe Statistical Library Copyright (C) 2000-2002 Andrew D. Martin 00003 * and Kevin M. Quinn; 2002-present Andrew D. Martin, Kevin M. Quinn, 00004 * and Daniel Pemstein. All Rights Reserved. 00005 * 00006 * This program is free software; you can redistribute it and/or 00007 * modify under the terms of the GNU General Public License as 00008 * published by Free Software Foundation; either version 2 of the 00009 * License, or (at your option) any later version. See the text files 00010 * COPYING and LICENSE, distributed with this source code, for further 00011 * information. 00012 * -------------------------------------------------------------------- 00013 * scythestat/rng/rtmvnorm.h 00014 * 00015 */ 00016 00028 #ifndef SCYTHE_RTMVNORM_H 00029 #define SCYTHE_RTMVNORM_H 00030 00031 #include <iostream> 00032 #include <cmath> 00033 00034 #ifdef SCYTHE_COMPILE_DIRECT 00035 #include "matrix.h" 00036 #include "rng.h" 00037 #include "error.h" 00038 #include "algorithm.h" 00039 #include "ide.h" 00040 #else 00041 #include "scythestat/matrix.h" 00042 #include "scythestat/rng.h" 00043 #include "scythestat/error.h" 00044 #include "scythestat/algorithm.h" 00045 #include "scythestat/ide.h" 00046 #endif 00047 namespace scythe 00048 { 00049 /* Truncated Multivariate Normal Distribution by Gibbs sampling 00050 * (Geweke 1991). This is a functor that allows one to 00051 * initialize---and optionally burn in---a sampler for a given 00052 * truncated multivariate normal distribution on construction 00053 * and then make (optionally thinned) draws with calls to the () 00054 * operator. 00055 * 00056 */ 00067 template <class RNGTYPE> 00068 class rtmvnorm { 00069 public: 00070 00120 template <matrix_order PO1, matrix_style PS1, matrix_order PO2, 00121 matrix_style PS2, matrix_order PO3, matrix_style PS3, 00122 matrix_order PO4, matrix_style PS5, matrix_order PO5, 00123 matrix_style PS4> 00124 rtmvnorm (const Matrix<double, PO1,PS1>& mu, 00125 const Matrix<double, PO2, PS2>& sigma, 00126 const Matrix<double, PO3, PS3>& D, 00127 const Matrix<double, PO4, PS4>& a, 00128 const Matrix<double, PO5, PS5>& b, rng<RNGTYPE>& generator, 00129 unsigned int burnin = 0, unsigned int thin = 1, 00130 bool preinvertedD = false) 00131 : mu_ (mu), C_ (mu.rows(), mu.rows(), false), 00132 h_ (mu.rows(), 1, false), z_ (mu.rows(), 1, true, 0), 00133 generator_ (generator), n_ (mu.rows()), thin_ (thin), iter_ (0) 00134 { 00135 SCYTHE_CHECK_10(thin == 0, scythe_invalid_arg, 00136 "thin must be >= 1"); 00137 SCYTHE_CHECK_10(! mu.isColVector(), scythe_dimension_error, 00138 "mu not column vector"); 00139 SCYTHE_CHECK_10(! sigma.isSquare(), scythe_dimension_error, 00140 "sigma not square"); 00141 SCYTHE_CHECK_10(! D.isSquare(), scythe_dimension_error, 00142 "D not square"); 00143 SCYTHE_CHECK_10(! a.isColVector(), scythe_dimension_error, 00144 "a not column vector"); 00145 SCYTHE_CHECK_10(! b.isColVector(), scythe_dimension_error, 00146 "b not column vector"); 00147 SCYTHE_CHECK_10(sigma.rows() != n_ || D.rows() != n_ || 00148 a.rows() != n_ || b.rows() != n_, scythe_conformation_error, 00149 "mu, sigma, D, a, and b not conformable"); 00150 00151 // TODO will D * sigma * t(D) always be positive definite, 00152 // allowing us to use the faster invpd? 00153 if (preinvertedD) 00154 Dinv_ = D; 00155 else 00156 Dinv_ = inv(D); 00157 Matrix<> Tinv = inv(D * sigma * t(D)); 00158 alpha_ = a - D * mu; 00159 beta_ = b - D * mu; 00160 00161 // Check truncation bounds 00162 if (SCYTHE_DEBUG > 0) { 00163 for (unsigned int i = 0; i < n_; ++i) { 00164 SCYTHE_CHECK(alpha_(i) >= beta_(i), scythe_invalid_arg, 00165 "Truncation bound " << i 00166 << " not logically consistent"); 00167 } 00168 } 00169 00170 // Precompute some stuff (see Geweke 1991 pg 7). 00171 for (unsigned int i = 0; i < n_; ++i) { 00172 C_(i, _) = -(1 / Tinv(i, i)) % Tinv(i, _); 00173 C_(i, i) = 0; // not really clever but probably too clever 00174 h_(i) = std::sqrt(1 / Tinv(i, i)); 00175 SCYTHE_CHECK_30(std::isnan(h_(i)), scythe_invalid_arg, 00176 "sigma is not positive definite"); 00177 } 00178 00179 // Do burnin 00180 for (unsigned int i = 0; i < burnin; ++i) 00181 sample (); 00182 } 00183 00193 template <matrix_order O, matrix_style S> 00194 Matrix<double, O, S> operator() () 00195 { 00196 do { sample (); } while (iter_ % thin_ != 0); 00197 00198 return (mu_ + Dinv_ * z_); 00199 } 00200 00207 Matrix<double,Col,Concrete> operator() () 00208 { 00209 return operator()<Col, Concrete>(); 00210 } 00211 00212 protected: 00213 /* Does one step of the Gibbs sampler (see Geweke 1991 p 6) */ 00214 void sample () 00215 { 00216 double czsum; 00217 double above; 00218 double below; 00219 for (unsigned int i = 0; i < n_; ++i) { 00220 00221 // Calculate sum_{j \ne i} c_{ij} z_{j} 00222 czsum = 0; 00223 for (unsigned int j = 0; j < n_; ++j) { 00224 if (i == j) continue; 00225 czsum += C_(i, j) * z_(j); 00226 } 00227 00228 // Calc truncation of conditional univariate std normal 00229 below = (alpha_(i) - czsum) / h_(i); 00230 above = (beta_(i) - czsum) / h_(i); 00231 00232 // Draw random variate z_i 00233 z_(i) = h_(i); 00234 if (above == std::numeric_limits<double>::infinity()){ 00235 if (below == -std::numeric_limits<double>::infinity()) 00236 z_(i) *= generator_.rnorm(0, 1); // untruncated 00237 else 00238 z_(i) *= generator_.rtbnorm_combo(0, 1, below); 00239 } else if (below == 00240 -std::numeric_limits<double>::infinity()) 00241 z_(i) *= generator_.rtanorm_combo(0, 1, above); 00242 else 00243 z_(i) *= generator_.rtnorm_combo(0, 1, below, above); 00244 00245 z_(i) += czsum; 00246 } 00247 00248 ++iter_; 00249 } 00250 00251 /* Instance variables */ 00252 // Various reused computation matrices with names from 00253 // Geweke 1991. 00254 Matrix<> mu_; Matrix<> Dinv_; 00255 Matrix<> C_; Matrix<> alpha_; Matrix<> beta_; Matrix<> h_; 00256 00257 Matrix<> z_; // The current draw of the posterior 00258 00259 rng<RNGTYPE>& generator_; // Refernce to random number generator 00260 00261 unsigned int n_; // The dimension of the distribution 00262 unsigned int thin_; // thinning parameter 00263 unsigned int iter_; // The current post-burnin iteration 00264 }; 00265 } // end namespace scythe 00266 #endif