contrib/mul/vpdfl/vpdfl_mixture.cxx
Go to the documentation of this file.
00001 // This is mul/vpdfl/vpdfl_mixture.cxx
00002 // Copyright: (C) 1998 Victoria University of Manchester
00003 #include "vpdfl_mixture.h"
00004 //:
00005 // \file
00006 // \brief Implements a mixture model (a set of individual pdfs + weights)
00007 // \author Tim Cootes
00008 // \date 21-July-98
00009 //
00010 // \verbatim
00011 //  Modifications
00012 //    IMS   Converted to VXL 12 May 2000
00013 // \endverbatim
00014 
00015 //=======================================================================
00016 
00017 #include <vcl_cmath.h>
00018 #include <vcl_cstdlib.h>
00019 #include <vcl_string.h>
00020 #include <vsl/vsl_indent.h>
00021 #include <vsl/vsl_binary_loader.h>
00022 #include <vpdfl/vpdfl_mixture_sampler.h>
00023 #include <vcl_cassert.h>
00024 #include <vsl/vsl_vector_io.h>
00025 #include <vnl/vnl_c_vector.h>
00026 #include <vnl/vnl_math.h>
00027 
00028 //=======================================================================
00029 
00030 
00031 vpdfl_mixture::vpdfl_mixture()
00032 {
00033 }
00034 
00035 vpdfl_mixture::vpdfl_mixture(const vpdfl_mixture& m):
00036   vpdfl_pdf_base()
00037 {
00038   *this = m;
00039 }
00040 
00041 vpdfl_mixture& vpdfl_mixture::operator=(const vpdfl_mixture& m)
00042 {
00043   if (this==&m) return *this;
00044 
00045   delete_stuff();
00046 
00047   vpdfl_pdf_base::operator=(m);
00048 
00049   unsigned n = m.component_.size();
00050   component_.resize(n);
00051   for (unsigned i=0;i<n;++i)
00052     component_[i] = m.component_[i]->clone();
00053 
00054   weight_ = m.weight_;
00055 
00056   return *this;
00057 }
00058 
00059 //=======================================================================
00060 
00061 void vpdfl_mixture::delete_stuff()
00062 {
00063   unsigned n = component_.size();
00064   for (unsigned i=0;i<n;++i)
00065     delete component_[i];
00066   component_.resize(0);
00067   weight_.resize(0);
00068 }
00069 
00070 vpdfl_mixture::~vpdfl_mixture()
00071 {
00072   delete_stuff();
00073 }
00074 
00075 //=======================================================================
00076 //: Return instance of this PDF
00077 vpdfl_sampler_base* vpdfl_mixture::new_sampler() const
00078 {
00079   vpdfl_mixture_sampler* i = new vpdfl_mixture_sampler;
00080   i->set_model(*this);
00081 
00082   return i;
00083 }
00084 
00085 //=======================================================================
00086 //: Initialise to use n components of type comp_type
00087 void vpdfl_mixture::init(const vpdfl_pdf_base& comp_type, unsigned n)
00088 {
00089   delete_stuff();
00090   component_.resize(n);
00091   weight_.resize(n);
00092   for (unsigned i=0;i<n;++i)
00093   {
00094     component_[i] = comp_type.clone();
00095     weight_[i] = 1.0/n;
00096   }
00097 }
00098 
00099 
00100 //=======================================================================
00101 //: Add Y*v to X
00102 static inline void incXbyYv(vnl_vector<double> *X, const vnl_vector<double> &Y, double v)
00103 {
00104   assert(X->size() == Y.size());
00105   int i = ((int)X->size()) - 1;
00106   double * const pX=X->data_block();
00107   while (i >= 0)
00108   {
00109     pX[i] += Y[i] * v;
00110     i--;
00111   }
00112 }
00113 
00114 //: Add (Y + Z.*Z)*v to X
00115 static inline void incXbyYplusXXv(vnl_vector<double> *X, const vnl_vector<double> &Y,
00116                                   const vnl_vector<double> &Z, double v)
00117 {
00118   assert(X->size() == Y.size());
00119   int i = ((int)X->size()) - 1;
00120   double * const pX=X->data_block();
00121   while (i >= 0)
00122   {
00123     pX[i] += (Y[i] + vnl_math_sqr(Z[i]))* v;
00124     i--;
00125   }
00126 }
00127 
00128 //=======================================================================
00129 //: Set the contents of the mixture model.
00130 // Clones are taken of all the data, and the class will be responsible for their deletion.
00131 void vpdfl_mixture::set(const vcl_vector<vpdfl_pdf_base*> components, const vcl_vector<double> & weights)
00132 {
00133   unsigned n = components.size();
00134   assert (weights.size() == n);
00135 
00136   component_.resize(n);
00137   for (unsigned int i=0; i<n; ++i)
00138     component_[i] = components[i]->clone();
00139 
00140   weight_ = weights;
00141 
00142   // Calculate the mixtures overall mean and variance.
00143 
00144   unsigned  m = (n==0)?0:component_[0]->mean().size();
00145   vnl_vector<double> mean(m, 0.0);
00146   vnl_vector<double> var(m, 0.0);
00147 
00148   for (unsigned i=0; i<n; ++i)
00149   {
00150     incXbyYv(&mean, component_[i]->mean(), weight_[i]);
00151     incXbyYplusXXv(&var, component_[i]->variance(),
00152       component_[i]->mean(), weight_[i]);
00153   }
00154 
00155   for (unsigned i=0; i<m; ++i)
00156     var(i) -= vnl_math_sqr(mean(i));
00157 
00158   set_mean(mean);
00159   set_variance(var);
00160 }
00161 
00162 //=======================================================================
00163 
00164 void vpdfl_mixture::add_component(const vpdfl_pdf_base& comp)
00165 {
00166   vcl_vector<vpdfl_pdf_base*> old_comps = component_;
00167   vcl_vector<double> old_wts = weight_;
00168   unsigned n = component_.size();
00169   assert(n == weight_.size());
00170 
00171   component_.resize(n+1);
00172   weight_.resize(n+1);
00173 
00174   for (unsigned i=0;i<n;++i)
00175   {
00176     component_[i] = old_comps[i];
00177     weight_[i] = old_wts[i];
00178   }
00179 
00180   weight_[n] = 0.0;
00181   component_[n] = comp.clone();
00182 }
00183 
00184 //=======================================================================
00185 
00186 void vpdfl_mixture::clear()
00187 {
00188   delete_stuff();
00189 }
00190 
00191 //=======================================================================
00192 
00193 //: Return true if the object represents a valid PDF.
00194 // This will return false, if n_dims() is 0, for example just ofter
00195 // default construction.
00196 bool vpdfl_mixture::is_valid_pdf() const
00197 {
00198   if (!vpdfl_pdf_base::is_valid_pdf()) return false;
00199   const unsigned n = n_components();
00200     // the number of components should be consistent
00201   if (weight_.size() != n || component_.size() != n || n < 1) return false;
00202     // weights should sum to 1.
00203   double sum =vnl_c_vector<double>::sum(&weight_[0]/*.begin()*/, n);
00204   if (vcl_fabs(1.0 - sum) > 1e-10 ) return false;
00205     // the number of dimensions should be consistent
00206   for (unsigned i=0; i<n; ++i)
00207   {
00208     if (!components()[i]->is_valid_pdf()) return false;
00209     if (components()[i]->n_dims() != n_dims()) return false;
00210   }
00211   return true;
00212 }
00213 
00214 //=======================================================================
00215 //: Set the whole pdf mean and variance values.
00216 void vpdfl_mixture::set_mean_and_variance(vnl_vector<double>&m, vnl_vector<double>&v)
00217 {
00218   assert(m.size() == v.size());
00219   set_mean(m);
00220   set_variance(v);
00221 }
00222 
00223 //=======================================================================
00224 
00225 vcl_string vpdfl_mixture::is_a() const
00226 {
00227   return vcl_string("vpdfl_mixture");
00228 }
00229 
00230 //=======================================================================
00231 
00232 bool vpdfl_mixture::is_class(vcl_string const& s) const
00233 {
00234   return vpdfl_pdf_base::is_class(s) || s==vpdfl_mixture::is_a();
00235 }
00236 
00237 //=======================================================================
00238 
00239 short vpdfl_mixture::version_no() const
00240 {
00241   return 1;
00242 }
00243 
00244 //=======================================================================
00245 
00246 vpdfl_pdf_base* vpdfl_mixture::clone() const
00247 {
00248   return new vpdfl_mixture(*this);
00249 }
00250 
00251 //=======================================================================
00252 
00253 void vpdfl_mixture::print_summary(vcl_ostream& os) const
00254 {
00255   os<<'\n'<<vsl_indent();
00256   vpdfl_pdf_base::print_summary(os);
00257   os<<'\n';
00258   for (unsigned int i=0;i<component_.size();++i)
00259   {
00260     os<<vsl_indent()<<"Component "<<i<<" :  Wt: "<<weight_[i] <<'\n'
00261       <<vsl_indent()<<"PDF: " << component_[i]<<'\n';
00262   }
00263 }
00264 
00265 //=======================================================================
00266 
00267 void vpdfl_mixture::b_write(vsl_b_ostream& bfs) const
00268 {
00269   vsl_b_write(bfs, is_a());
00270   vsl_b_write(bfs, version_no());
00271   vpdfl_pdf_base::b_write(bfs);
00272   vsl_b_write(bfs, component_);
00273   vsl_b_write(bfs, weight_);
00274 }
00275 
00276 //=======================================================================
00277 
00278 void vpdfl_mixture::b_read(vsl_b_istream& bfs)
00279 {
00280   if (!bfs) return;
00281 
00282   vcl_string name;
00283   vsl_b_read(bfs,name);
00284   if (name != is_a())
00285   {
00286     vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, vpdfl_mixture &)\n"
00287              << "           Attempted to load object of type "
00288              << name <<" into object of type " << is_a() << '\n';
00289     bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00290     return;
00291   }
00292 
00293   delete_stuff();
00294 
00295   short version;
00296   vsl_b_read(bfs,version);
00297   switch (version)
00298   {
00299     case (1):
00300       vpdfl_pdf_base::b_read(bfs);
00301       vsl_b_read(bfs, component_);
00302       vsl_b_read(bfs, weight_);
00303       break;
00304     default:
00305       vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, vpdfl_mixture &)\n"
00306                << "           Unknown version number "<< version << '\n';
00307       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00308       return;
00309   }
00310 }
00311 
00312 //=======================================================================
00313 
00314 double vpdfl_mixture::operator()(const vnl_vector<double>& x) const
00315 {
00316   return vcl_exp(log_p(x));
00317 }
00318 
00319 //=======================================================================
00320 
00321 double vpdfl_mixture::log_p(const vnl_vector<double>& x) const
00322 {
00323   int n = n_components();
00324 
00325   vnl_vector<double>& log_ps = ws_;
00326   log_ps.set_size(n);
00327 
00328   double max_log_p = 0.0; // initialise just to make the compiler happy
00329   for (int i=0;i<n;++i)
00330   {
00331     if (weight_[i]>0.0)
00332     {
00333       log_ps[i] = component_[i]->log_p(x);
00334       if (i==0 || log_ps[i]>max_log_p) max_log_p = log_ps[i];
00335     }
00336   }
00337 
00338   double sum=0.0;
00339 
00340   for (int i=0;i<n;i++)
00341   {
00342     if (weight_[i]>0.0)
00343     sum += weight_[i] * vcl_exp(log_ps[i]-max_log_p);
00344   }
00345 
00346   return vcl_log(sum) + max_log_p;
00347 }
00348 
00349 //=======================================================================
00350 
00351 void vpdfl_mixture::gradient(vnl_vector<double>& g,
00352                                      const vnl_vector<double>& x,
00353                                      double& p) const
00354 {
00355   vnl_vector<double>& g1 = ws_;
00356 
00357   double p1;
00358   component_[0]->gradient(g1,x,p1);
00359   g = g1*weight_[0];
00360   p = p1*weight_[0];
00361 
00362   for (unsigned int i=1;i<n_components();i++)
00363   {
00364     component_[i]->gradient(g1,x,p1);
00365     g += g1*weight_[i];
00366     double p_comp = p1*weight_[i];
00367     p += p_comp;
00368   }
00369 }
00370 
00371 //=======================================================================
00372 
00373 unsigned vpdfl_mixture::nearest_comp(const vnl_vector<double>& x) const
00374 {
00375   assert(component_.size()>=1);
00376 
00377   int n = n_components();
00378   if (n==1) return 0;
00379 
00380   int best_i=0;
00381   double min_d2 = vnl_vector_ssd(x, component_[0]->mean());;
00382 
00383   for (int i=1;i<n;i++)
00384   {
00385     double d2 = vnl_vector_ssd(x, component_[i]->mean());
00386     if (d2<min_d2)
00387     {
00388       best_i=i;
00389       min_d2=d2;
00390     }
00391   }
00392 
00393   return best_i;
00394 }
00395 
00396 //=======================================================================
00397 
00398 //: Compute nearest point to x which has a density above a threshold
00399 //  If log_p(x)>log_p_min then x unchanged.  Otherwise x is moved
00400 //  (typically up the gradient) until log_p(x)>=log_p_min.
00401 // \param x This may be modified to the nearest plausible position.
00402 void vpdfl_mixture::nearest_plausible(vnl_vector<double>& /*x*/, double /*log_p_min*/) const
00403 {
00404   vcl_cerr << "ERROR: vpdfl_mixture::nearest_plausible NYI\n";
00405   vcl_abort();
00406 }
00407 
00408 //==================< end of file: vpdfl_mixture.cxx >====================