Go to the documentation of this file.00001
00002
00003 #include "vpdfl_mixture.h"
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #include <vcl_cmath.h>
00018 #include <vcl_cstdlib.h>
00019 #include <vcl_string.h>
00020 #include <vsl/vsl_indent.h>
00021 #include <vsl/vsl_binary_loader.h>
00022 #include <vpdfl/vpdfl_mixture_sampler.h>
00023 #include <vcl_cassert.h>
00024 #include <vsl/vsl_vector_io.h>
00025 #include <vnl/vnl_c_vector.h>
00026 #include <vnl/vnl_math.h>
00027
00028
00029
00030
00031 vpdfl_mixture::vpdfl_mixture()
00032 {
00033 }
00034
00035 vpdfl_mixture::vpdfl_mixture(const vpdfl_mixture& m):
00036 vpdfl_pdf_base()
00037 {
00038 *this = m;
00039 }
00040
00041 vpdfl_mixture& vpdfl_mixture::operator=(const vpdfl_mixture& m)
00042 {
00043 if (this==&m) return *this;
00044
00045 delete_stuff();
00046
00047 vpdfl_pdf_base::operator=(m);
00048
00049 unsigned n = m.component_.size();
00050 component_.resize(n);
00051 for (unsigned i=0;i<n;++i)
00052 component_[i] = m.component_[i]->clone();
00053
00054 weight_ = m.weight_;
00055
00056 return *this;
00057 }
00058
00059
00060
00061 void vpdfl_mixture::delete_stuff()
00062 {
00063 unsigned n = component_.size();
00064 for (unsigned i=0;i<n;++i)
00065 delete component_[i];
00066 component_.resize(0);
00067 weight_.resize(0);
00068 }
00069
00070 vpdfl_mixture::~vpdfl_mixture()
00071 {
00072 delete_stuff();
00073 }
00074
00075
00076
00077 vpdfl_sampler_base* vpdfl_mixture::new_sampler() const
00078 {
00079 vpdfl_mixture_sampler* i = new vpdfl_mixture_sampler;
00080 i->set_model(*this);
00081
00082 return i;
00083 }
00084
00085
00086
00087 void vpdfl_mixture::init(const vpdfl_pdf_base& comp_type, unsigned n)
00088 {
00089 delete_stuff();
00090 component_.resize(n);
00091 weight_.resize(n);
00092 for (unsigned i=0;i<n;++i)
00093 {
00094 component_[i] = comp_type.clone();
00095 weight_[i] = 1.0/n;
00096 }
00097 }
00098
00099
00100
00101
00102 static inline void incXbyYv(vnl_vector<double> *X, const vnl_vector<double> &Y, double v)
00103 {
00104 assert(X->size() == Y.size());
00105 int i = ((int)X->size()) - 1;
00106 double * const pX=X->data_block();
00107 while (i >= 0)
00108 {
00109 pX[i] += Y[i] * v;
00110 i--;
00111 }
00112 }
00113
00114
00115 static inline void incXbyYplusXXv(vnl_vector<double> *X, const vnl_vector<double> &Y,
00116 const vnl_vector<double> &Z, double v)
00117 {
00118 assert(X->size() == Y.size());
00119 int i = ((int)X->size()) - 1;
00120 double * const pX=X->data_block();
00121 while (i >= 0)
00122 {
00123 pX[i] += (Y[i] + vnl_math_sqr(Z[i]))* v;
00124 i--;
00125 }
00126 }
00127
00128
00129
00130
00131 void vpdfl_mixture::set(const vcl_vector<vpdfl_pdf_base*> components, const vcl_vector<double> & weights)
00132 {
00133 unsigned n = components.size();
00134 assert (weights.size() == n);
00135
00136 component_.resize(n);
00137 for (unsigned int i=0; i<n; ++i)
00138 component_[i] = components[i]->clone();
00139
00140 weight_ = weights;
00141
00142
00143
00144 unsigned m = (n==0)?0:component_[0]->mean().size();
00145 vnl_vector<double> mean(m, 0.0);
00146 vnl_vector<double> var(m, 0.0);
00147
00148 for (unsigned i=0; i<n; ++i)
00149 {
00150 incXbyYv(&mean, component_[i]->mean(), weight_[i]);
00151 incXbyYplusXXv(&var, component_[i]->variance(),
00152 component_[i]->mean(), weight_[i]);
00153 }
00154
00155 for (unsigned i=0; i<m; ++i)
00156 var(i) -= vnl_math_sqr(mean(i));
00157
00158 set_mean(mean);
00159 set_variance(var);
00160 }
00161
00162
00163
00164 void vpdfl_mixture::add_component(const vpdfl_pdf_base& comp)
00165 {
00166 vcl_vector<vpdfl_pdf_base*> old_comps = component_;
00167 vcl_vector<double> old_wts = weight_;
00168 unsigned n = component_.size();
00169 assert(n == weight_.size());
00170
00171 component_.resize(n+1);
00172 weight_.resize(n+1);
00173
00174 for (unsigned i=0;i<n;++i)
00175 {
00176 component_[i] = old_comps[i];
00177 weight_[i] = old_wts[i];
00178 }
00179
00180 weight_[n] = 0.0;
00181 component_[n] = comp.clone();
00182 }
00183
00184
00185
00186 void vpdfl_mixture::clear()
00187 {
00188 delete_stuff();
00189 }
00190
00191
00192
00193
00194
00195
00196 bool vpdfl_mixture::is_valid_pdf() const
00197 {
00198 if (!vpdfl_pdf_base::is_valid_pdf()) return false;
00199 const unsigned n = n_components();
00200
00201 if (weight_.size() != n || component_.size() != n || n < 1) return false;
00202
00203 double sum =vnl_c_vector<double>::sum(&weight_[0], n);
00204 if (vcl_fabs(1.0 - sum) > 1e-10 ) return false;
00205
00206 for (unsigned i=0; i<n; ++i)
00207 {
00208 if (!components()[i]->is_valid_pdf()) return false;
00209 if (components()[i]->n_dims() != n_dims()) return false;
00210 }
00211 return true;
00212 }
00213
00214
00215
00216 void vpdfl_mixture::set_mean_and_variance(vnl_vector<double>&m, vnl_vector<double>&v)
00217 {
00218 assert(m.size() == v.size());
00219 set_mean(m);
00220 set_variance(v);
00221 }
00222
00223
00224
00225 vcl_string vpdfl_mixture::is_a() const
00226 {
00227 return vcl_string("vpdfl_mixture");
00228 }
00229
00230
00231
00232 bool vpdfl_mixture::is_class(vcl_string const& s) const
00233 {
00234 return vpdfl_pdf_base::is_class(s) || s==vpdfl_mixture::is_a();
00235 }
00236
00237
00238
00239 short vpdfl_mixture::version_no() const
00240 {
00241 return 1;
00242 }
00243
00244
00245
00246 vpdfl_pdf_base* vpdfl_mixture::clone() const
00247 {
00248 return new vpdfl_mixture(*this);
00249 }
00250
00251
00252
00253 void vpdfl_mixture::print_summary(vcl_ostream& os) const
00254 {
00255 os<<'\n'<<vsl_indent();
00256 vpdfl_pdf_base::print_summary(os);
00257 os<<'\n';
00258 for (unsigned int i=0;i<component_.size();++i)
00259 {
00260 os<<vsl_indent()<<"Component "<<i<<" : Wt: "<<weight_[i] <<'\n'
00261 <<vsl_indent()<<"PDF: " << component_[i]<<'\n';
00262 }
00263 }
00264
00265
00266
00267 void vpdfl_mixture::b_write(vsl_b_ostream& bfs) const
00268 {
00269 vsl_b_write(bfs, is_a());
00270 vsl_b_write(bfs, version_no());
00271 vpdfl_pdf_base::b_write(bfs);
00272 vsl_b_write(bfs, component_);
00273 vsl_b_write(bfs, weight_);
00274 }
00275
00276
00277
00278 void vpdfl_mixture::b_read(vsl_b_istream& bfs)
00279 {
00280 if (!bfs) return;
00281
00282 vcl_string name;
00283 vsl_b_read(bfs,name);
00284 if (name != is_a())
00285 {
00286 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, vpdfl_mixture &)\n"
00287 << " Attempted to load object of type "
00288 << name <<" into object of type " << is_a() << '\n';
00289 bfs.is().clear(vcl_ios::badbit);
00290 return;
00291 }
00292
00293 delete_stuff();
00294
00295 short version;
00296 vsl_b_read(bfs,version);
00297 switch (version)
00298 {
00299 case (1):
00300 vpdfl_pdf_base::b_read(bfs);
00301 vsl_b_read(bfs, component_);
00302 vsl_b_read(bfs, weight_);
00303 break;
00304 default:
00305 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, vpdfl_mixture &)\n"
00306 << " Unknown version number "<< version << '\n';
00307 bfs.is().clear(vcl_ios::badbit);
00308 return;
00309 }
00310 }
00311
00312
00313
00314 double vpdfl_mixture::operator()(const vnl_vector<double>& x) const
00315 {
00316 return vcl_exp(log_p(x));
00317 }
00318
00319
00320
00321 double vpdfl_mixture::log_p(const vnl_vector<double>& x) const
00322 {
00323 int n = n_components();
00324
00325 vnl_vector<double>& log_ps = ws_;
00326 log_ps.set_size(n);
00327
00328 double max_log_p = 0.0;
00329 for (int i=0;i<n;++i)
00330 {
00331 if (weight_[i]>0.0)
00332 {
00333 log_ps[i] = component_[i]->log_p(x);
00334 if (i==0 || log_ps[i]>max_log_p) max_log_p = log_ps[i];
00335 }
00336 }
00337
00338 double sum=0.0;
00339
00340 for (int i=0;i<n;i++)
00341 {
00342 if (weight_[i]>0.0)
00343 sum += weight_[i] * vcl_exp(log_ps[i]-max_log_p);
00344 }
00345
00346 return vcl_log(sum) + max_log_p;
00347 }
00348
00349
00350
00351 void vpdfl_mixture::gradient(vnl_vector<double>& g,
00352 const vnl_vector<double>& x,
00353 double& p) const
00354 {
00355 vnl_vector<double>& g1 = ws_;
00356
00357 double p1;
00358 component_[0]->gradient(g1,x,p1);
00359 g = g1*weight_[0];
00360 p = p1*weight_[0];
00361
00362 for (unsigned int i=1;i<n_components();i++)
00363 {
00364 component_[i]->gradient(g1,x,p1);
00365 g += g1*weight_[i];
00366 double p_comp = p1*weight_[i];
00367 p += p_comp;
00368 }
00369 }
00370
00371
00372
00373 unsigned vpdfl_mixture::nearest_comp(const vnl_vector<double>& x) const
00374 {
00375 assert(component_.size()>=1);
00376
00377 int n = n_components();
00378 if (n==1) return 0;
00379
00380 int best_i=0;
00381 double min_d2 = vnl_vector_ssd(x, component_[0]->mean());;
00382
00383 for (int i=1;i<n;i++)
00384 {
00385 double d2 = vnl_vector_ssd(x, component_[i]->mean());
00386 if (d2<min_d2)
00387 {
00388 best_i=i;
00389 min_d2=d2;
00390 }
00391 }
00392
00393 return best_i;
00394 }
00395
00396
00397
00398
00399
00400
00401
00402 void vpdfl_mixture::nearest_plausible(vnl_vector<double>& , double ) const
00403 {
00404 vcl_cerr << "ERROR: vpdfl_mixture::nearest_plausible NYI\n";
00405 vcl_abort();
00406 }
00407
00408