Scythe-1.0.3
rtmvnorm.h
Go to the documentation of this file.
00001 /* 
00002  * Scythe Statistical Library Copyright (C) 2000-2002 Andrew D. Martin
00003  * and Kevin M. Quinn; 2002-present Andrew D. Martin, Kevin M. Quinn,
00004  * and Daniel Pemstein.  All Rights Reserved.
00005  *
00006  * This program is free software; you can redistribute it and/or
00007  * modify under the terms of the GNU General Public License as
00008  * published by Free Software Foundation; either version 2 of the
00009  * License, or (at your option) any later version.  See the text files
00010  * COPYING and LICENSE, distributed with this source code, for further
00011  * information.
00012  * --------------------------------------------------------------------
00013  *  scythestat/rng/rtmvnorm.h
00014  *
00015  */
00016 
00028 #ifndef SCYTHE_RTMVNORM_H
00029 #define SCYTHE_RTMVNORM_H
00030 
#include <cmath>
#include <iostream>
#include <limits>
00033 
00034 #ifdef SCYTHE_COMPILE_DIRECT
00035 #include "matrix.h"
00036 #include "rng.h"
00037 #include "error.h"
00038 #include "algorithm.h"
00039 #include "ide.h"
00040 #else
00041 #include "scythestat/matrix.h"
00042 #include "scythestat/rng.h"
00043 #include "scythestat/error.h"
00044 #include "scythestat/algorithm.h"
00045 #include "scythestat/ide.h"
00046 #endif
00047 namespace scythe
00048 {
00049   /* Truncated Multivariate Normal Distribution by Gibbs sampling
00050    * (Geweke 1991).  This is a functor that allows one to
00051    * initialize---and optionally burn in---a sampler for a given
00052    * truncated multivariate normal distribution on construction
00053    * and then make (optionally thinned) draws with calls to the ()
00054    * operator.
00055    *
00056    */
00067   template <class RNGTYPE>
00068   class rtmvnorm {
00069     public:
00070 
00120       template <matrix_order PO1, matrix_style PS1, matrix_order PO2,
00121                 matrix_style PS2, matrix_order PO3, matrix_style PS3,
00122                 matrix_order PO4, matrix_style PS5, matrix_order PO5,
00123                 matrix_style PS4>
00124       rtmvnorm (const Matrix<double, PO1,PS1>& mu,
00125                 const Matrix<double, PO2, PS2>& sigma,
00126                 const Matrix<double, PO3, PS3>& D,
00127                 const Matrix<double, PO4, PS4>& a,
00128                 const Matrix<double, PO5, PS5>& b, rng<RNGTYPE>& generator, 
00129                 unsigned int burnin = 0, unsigned int thin = 1,
00130                 bool preinvertedD = false) 
00131       : mu_ (mu), C_ (mu.rows(), mu.rows(), false), 
00132         h_ (mu.rows(), 1, false), z_ (mu.rows(), 1, true, 0), 
00133         generator_ (generator), n_ (mu.rows()), thin_ (thin), iter_ (0)
00134       {
00135         SCYTHE_CHECK_10(thin == 0, scythe_invalid_arg,
00136             "thin must be >= 1");
00137         SCYTHE_CHECK_10(! mu.isColVector(), scythe_dimension_error,
00138             "mu not column vector");
00139         SCYTHE_CHECK_10(!  sigma.isSquare(), scythe_dimension_error,
00140             "sigma not square");
00141         SCYTHE_CHECK_10(!  D.isSquare(), scythe_dimension_error,
00142             "D not square");
00143         SCYTHE_CHECK_10(!  a.isColVector(), scythe_dimension_error,
00144             "a not column vector");
00145         SCYTHE_CHECK_10(!  b.isColVector(), scythe_dimension_error,
00146             "b not column vector");
00147         SCYTHE_CHECK_10(sigma.rows() != n_ || D.rows() != n_ || 
00148             a.rows() != n_ || b.rows() != n_, scythe_conformation_error,
00149             "mu, sigma, D, a, and b not conformable");
00150 
00151         // TODO will D * sigma * t(D) always be positive definite,
00152         // allowing us to use the faster invpd?
00153         if (preinvertedD)
00154           Dinv_ = D;
00155         else
00156           Dinv_ = inv(D);
00157         Matrix<> Tinv = inv(D * sigma * t(D));
00158         alpha_ = a - D * mu;
00159         beta_ =  b - D * mu;
00160 
00161         // Check truncation bounds
00162         if (SCYTHE_DEBUG > 0) {
00163           for (unsigned int i = 0; i < n_; ++i) {
00164             SCYTHE_CHECK(alpha_(i) >= beta_(i), scythe_invalid_arg,
00165                 "Truncation bound " << i 
00166                 << " not logically consistent");
00167           }
00168         }
00169 
00170         // Precompute some stuff (see Geweke 1991 pg 7).
00171         for (unsigned int i = 0; i < n_; ++i) {
00172           C_(i, _) = -(1 / Tinv(i, i)) % Tinv(i, _);
00173           C_(i, i) = 0; // not really clever but probably too clever
00174           h_(i) = std::sqrt(1 / Tinv(i, i));
00175           SCYTHE_CHECK_30(std::isnan(h_(i)), scythe_invalid_arg,
00176               "sigma is not positive definite");
00177         }
00178 
00179         // Do burnin
00180         for (unsigned int i = 0; i < burnin; ++i)
00181           sample ();
00182       }
00183 
00193       template <matrix_order O, matrix_style S>
00194       Matrix<double, O, S> operator() ()
00195       {
00196         do { sample (); } while (iter_ % thin_ != 0);
00197 
00198         return (mu_ + Dinv_ * z_);
00199       }
00200 
00207       Matrix<double,Col,Concrete> operator() ()
00208       {
00209         return operator()<Col, Concrete>();
00210       }
00211 
00212     protected:
00213       /* Does one step of the Gibbs sampler (see Geweke 1991 p 6) */
00214       void sample ()
00215       {
00216         double czsum;
00217         double above;
00218         double below;
00219         for (unsigned int i = 0; i < n_; ++i) {
00220 
00221           // Calculate sum_{j \ne i} c_{ij} z_{j}
00222           czsum = 0;
00223           for (unsigned int j = 0; j < n_; ++j) {
00224             if (i == j) continue;
00225             czsum += C_(i, j) * z_(j);
00226           }
00227 
00228           // Calc truncation of conditional univariate std normal
00229           below = (alpha_(i) - czsum) / h_(i);
00230           above = (beta_(i) - czsum) / h_(i);
00231           
00232           // Draw random variate z_i
00233           z_(i) = h_(i);
00234           if (above == std::numeric_limits<double>::infinity()){
00235             if (below == -std::numeric_limits<double>::infinity())
00236               z_(i) *= generator_.rnorm(0, 1); // untruncated
00237             else
00238               z_(i) *= generator_.rtbnorm_combo(0, 1, below);
00239           } else if (below == 
00240               -std::numeric_limits<double>::infinity())
00241             z_(i) *= generator_.rtanorm_combo(0, 1, above);
00242           else
00243             z_(i) *= generator_.rtnorm_combo(0, 1, below, above);
00244 
00245           z_(i) += czsum;
00246         }
00247 
00248         ++iter_;
00249       }
00250 
00251       /* Instance variables */
00252       // Various reused computation matrices with names from
00253       // Geweke 1991.
00254       Matrix<> mu_;   Matrix<> Dinv_; 
00255       Matrix<> C_; Matrix<> alpha_; Matrix<> beta_; Matrix<> h_; 
00256 
00257       Matrix<> z_; // The current draw of the posterior
00258 
00259       rng<RNGTYPE>& generator_; // Refernce to random number generator
00260 
00261       unsigned int n_;  // The dimension of the distribution
00262       unsigned int thin_; // thinning parameter
00263       unsigned int iter_; // The current post-burnin iteration
00264   };
00265 } // end namespace scythe
00266 #endif