00001
00002
00003
00004
00005
00006
00007 #include "vpdfl_mixture_builder.h"
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019 #include <vcl_sstream.h>
00020 #include <vcl_cassert.h>
00021 #include <vcl_cmath.h>
00022 #include <vcl_cstdlib.h>
00023 #include <vsl/vsl_indent.h>
00024 #include <vsl/vsl_vector_io.h>
00025 #include <vsl/vsl_binary_loader.h>
00026 #include <vpdfl/vpdfl_mixture.h>
00027 #include <mbl/mbl_data_wrapper.h>
00028 #include <mbl/mbl_data_array_wrapper.h>
00029 #include <vnl/vnl_math.h>
00030
00031 #include <mbl/mbl_parse_block.h>
00032 #include <mbl/mbl_read_props.h>
00033 #include <vul/vul_string.h>
00034 #include <mbl/mbl_exception.h>
00035
00036
// Threshold used by m_step(): a re-estimated mixing weight below this
// value is set to zero.
static const double min_wt = 1.0e-8;
00038
00039
00040 void vpdfl_mixture_builder::init()
00041 {
00042 min_var_ = 1.0e-6;
00043 max_its_ = 10;
00044 weights_fixed_ = false;
00045 initial_means_.clear();
00046 }
00047
00048
00049
00050 vpdfl_mixture_builder::vpdfl_mixture_builder()
00051 {
00052 init();
00053 }
00054
00055
00056
00057 vpdfl_mixture_builder::vpdfl_mixture_builder(const vpdfl_mixture_builder& b):
00058 vpdfl_builder_base()
00059 {
00060 init();
00061 *this = b;
00062 }
00063
00064
00065
00066 vpdfl_mixture_builder& vpdfl_mixture_builder::operator=(const vpdfl_mixture_builder& b)
00067 {
00068 if (&b==this) return *this;
00069
00070 delete_stuff();
00071
00072 unsigned int n = b.builder_.size();
00073 builder_.resize(n);
00074 for (unsigned int i=0;i<n;++i)
00075 builder_[i] = b.builder_[i]->clone();
00076
00077 min_var_ = b.min_var_;
00078 max_its_ = b.max_its_;
00079 weights_fixed_ = b.weights_fixed_;
00080 initial_means_ = b.initial_means_;
00081
00082 return *this;
00083 }
00084
00085
00086
00087 void vpdfl_mixture_builder::delete_stuff()
00088 {
00089 unsigned int n = builder_.size();
00090 for (unsigned int i=0;i<n;++i)
00091 delete builder_[i];
00092 builder_.resize(0);
00093 initial_means_.clear();
00094 }
00095
00096 vpdfl_mixture_builder::~vpdfl_mixture_builder()
00097 {
00098 delete_stuff();
00099 }
00100
00101
00102
00103
00104
00105 void vpdfl_mixture_builder::init(const vpdfl_builder_base& builder, int n)
00106 {
00107 delete_stuff();
00108 builder_.resize(n);
00109 for (int i=0;i<n;++i)
00110 builder_[i] = builder.clone();
00111 }
00112
00113
00114
00115
00116 void vpdfl_mixture_builder::set_max_iterations(int n)
00117 {
00118 max_its_ = n;
00119 }
00120
00121 void vpdfl_mixture_builder::set_weights_fixed(bool b)
00122 {
00123 weights_fixed_ = b;
00124 }
00125
00126
00127
00128
00129 vpdfl_pdf_base* vpdfl_mixture_builder::new_model() const
00130 {
00131 return new vpdfl_mixture;
00132 }
00133
00134
00135
00136
00137 void vpdfl_mixture_builder::set_min_var(double min_var)
00138 {
00139 min_var_ = min_var;
00140 }
00141
00142
00143
00144
00145 double vpdfl_mixture_builder::min_var() const
00146 {
00147 return min_var_;
00148 }
00149
00150
00151
00152
00153 void vpdfl_mixture_builder::build(vpdfl_pdf_base& ,
00154 const vnl_vector<double>& ) const
00155 {
00156 vcl_cerr<<"vpdfl_mixture_builder::build(model,mean) Not yet implemented.\n";
00157 vcl_abort();
00158 }
00159
00160
00161
00162
00163 void vpdfl_mixture_builder::build(vpdfl_pdf_base& model,
00164 mbl_data_wrapper<vnl_vector<double> >& data) const
00165 {
00166 vcl_vector<double> wts(int(data.size()), 1.0);
00167 weighted_build(model,data,wts);
00168 }
00169
00170
00171
00172
00173 void vpdfl_mixture_builder::weighted_build(vpdfl_pdf_base& base_model,
00174 mbl_data_wrapper<vnl_vector<double> >& data,
00175 const vcl_vector<double>& wts) const
00176 {
00177 assert(base_model.is_class("vpdfl_mixture"));
00178 vpdfl_mixture& model = static_cast<vpdfl_mixture&>( base_model);
00179
00180 unsigned int n = builder_.size();
00181
00182 bool model_setup = (model.n_components()==n);
00183
00184 if (!model_setup)
00185 {
00186
00187 model.clear();
00188 model.components().resize(n);
00189 model.weights().resize(n);
00190 for (unsigned int i=0;i<n;++i)
00191 {
00192 builder_[i]->set_min_var(min_var_);
00193 model.components()[i] = builder_[i]->new_model();
00194 model.weights()[i] = 1.0/n;
00195 }
00196 }
00197
00198
00199 const vnl_vector<double>* data_ptr;
00200 vcl_vector<vnl_vector<double> > data_array;
00201
00202 {
00203 unsigned int n=data.size();
00204 data.reset();
00205 data_array.resize(n);
00206 for (unsigned int i=0;i<n;++i)
00207 {
00208 data_array[i] = data.current();
00209 data.next();
00210 }
00211
00212 data_ptr = &data_array[0];
00213 }
00214
00215 if (!model_setup || !initial_means_.empty())
00216 initialise(model,data_ptr,wts);
00217
00218 vcl_vector<vnl_vector<double> > probs;
00219
00220 int n_its = 0;
00221 double max_move = 1e-6;
00222 double move = max_move+1;
00223 while (move>max_move && n_its<max_its_)
00224 {
00225 e_step(model,probs,data_ptr,wts);
00226 move = m_step(model,probs,data_ptr,wts);
00227 n_its++;
00228 }
00229 calc_mean_and_variance(model);
00230 assert(model.is_valid_pdf());
00231 }
00232
00233 static void UpdateRange(vnl_vector<double>& min_vec, vnl_vector<double>& max_vec, const vnl_vector<double>& vec)
00234 {
00235 unsigned int n=vec.size();
00236 for (unsigned int i=0;i<n;++i)
00237 {
00238 if (vec(i)<min_vec(i))
00239 min_vec(i)=vec(i);
00240 else
00241 if (vec(i)>max_vec(i))
00242 max_vec(i)=vec(i);
00243 }
00244 }
00245
00246
00247 void vpdfl_mixture_builder::initialise_given_means(vpdfl_mixture& model,
00248 const vnl_vector<double>* data,
00249 const vcl_vector<vnl_vector<double> >& mean,
00250 const vcl_vector<double>& wts) const
00251 {
00252 const unsigned int n_comp = builder_.size();
00253 const unsigned int n_samples = wts.size();
00254
00255
00256 vnl_vector<double> min_v(mean[0]);
00257 vnl_vector<double> max_v(min_v);
00258 for (unsigned int i=1;i<n_comp;++i)
00259 UpdateRange(min_v,max_v,mean[i]);
00260
00261 double mean_sep = vnl_vector_ssd(max_v,min_v)/n_samples;
00262 if (mean_sep<=1e-6) mean_sep = 1e-6;
00263
00264
00265 vcl_vector<double> wts_i(n_samples);
00266
00267 mbl_data_array_wrapper<vnl_vector<double> > data_array(data,n_samples);
00268
00269 for (unsigned int i=0;i<n_comp;++i)
00270 {
00271
00272 double w_sum = 0.0;
00273 for (unsigned int j=0;j<n_samples;++j)
00274 {
00275 wts_i[j] = wts[j]*mean_sep/(mean_sep+ vnl_vector_ssd(data[j], mean[i]));
00276 w_sum+=wts_i[j];
00277 }
00278
00279
00280 double f = n_samples/(n_comp*w_sum);
00281 for (unsigned int j=0;j<n_samples;++j)
00282 wts_i[j]*=f;
00283
00284
00285 builder_[i]->weighted_build(*(model.components()[i]),data_array,wts_i);
00286 }
00287 }
00288
00289
00290
00291 void vpdfl_mixture_builder::initialise_diagonal(vpdfl_mixture& model,
00292 const vnl_vector<double>* data,
00293 const vcl_vector<double>& wts) const
00294 {
00295
00296 const unsigned int n_comp = builder_.size();
00297 const unsigned int n_samples = wts.size();
00298
00299
00300 vnl_vector<double> min_v(data[0]);
00301 vnl_vector<double> max_v(min_v);
00302 for (unsigned int i=1;i<n_samples;++i)
00303 UpdateRange(min_v,max_v,data[i]);
00304
00305 #if 0 // unused variable
00306 double mean_sep = vnl_vector_ssd(max_v,min_v)/n_samples;
00307 #endif
00308
00309
00310 vcl_vector<vnl_vector<double> > mean(n_comp);
00311 for (unsigned int i=0;i<n_comp;++i)
00312 {
00313 double f = (i+1.0)/(n_comp+1);
00314 mean[i] = (1-f)*min_v + f*max_v;
00315 }
00316
00317 initialise_given_means(model,data,mean,wts);
00318 }
00319
00320
00321
00322 void vpdfl_mixture_builder::initialise_to_regular_samples(vpdfl_mixture& model,
00323 const vnl_vector<double>* data,
00324 const vcl_vector<double>& wts) const
00325 {
00326
00327 const unsigned int n_comp = builder_.size();
00328 const unsigned int n_samples = wts.size();
00329
00330 double f = double(n_samples)/n_comp;
00331
00332
00333 vcl_vector<vnl_vector<double> > mean(n_comp);
00334 for (unsigned int i=0;i<n_comp;++i)
00335 {
00336 unsigned int j = vnl_math_rnd((i+0.5)*f);
00337 if (j>=n_samples) j=n_samples-1;
00338 mean[i] = data[j];
00339 }
00340
00341 initialise_given_means(model,data,mean,wts);
00342 }
00343
00344 void vpdfl_mixture_builder::initialise(vpdfl_mixture& model,
00345 const vnl_vector<double>* data,
00346 const vcl_vector<double>& wts) const
00347 {
00348
00349 if (!initial_means_.empty() )
00350 {
00351 initialise_given_means(model,data,initial_means_,wts);
00352 }
00353 else
00354 {
00355 initialise_to_regular_samples(model,data,wts);
00356 }
00357 }
00358
00359 void vpdfl_mixture_builder::preset_initial_means(const vcl_vector<vnl_vector<double> >& component_means)
00360 {
00361 initial_means_ = component_means;
00362 }
00363
00364
00365 void vpdfl_mixture_builder::e_step(vpdfl_mixture& model,
00366 vcl_vector<vnl_vector<double> >& probs,
00367 const vnl_vector<double>* data,
00368 const vcl_vector<double>& wts) const
00369 {
00370 const unsigned int n_comp = builder_.size();
00371 const unsigned int n_egs = wts.size();
00372 const vcl_vector<double>& m_wts = model.weights();
00373
00374 if (probs.size()!=n_comp) probs.resize(n_comp);
00375
00376
00377
00378 for (unsigned int i=0;i<n_comp;++i)
00379 {
00380 if (probs[i].size()!=n_egs) probs[i].set_size(n_egs);
00381
00382
00383
00384 if (m_wts[i]<=0) continue;
00385
00386 double *p_data = probs[i].begin();
00387
00388 double log_wt_i = vcl_log(m_wts[i]);
00389
00390 for (unsigned int j=0;j<n_egs;++j)
00391 {
00392 p_data[j] = log_wt_i+model.components()[i]->log_p(data[j]);
00393 }
00394 }
00395
00396
00397
00398 for (unsigned int j=0;j<n_egs;++j)
00399 {
00400
00401 double max_log_p=0;
00402 for (unsigned int i=0;i<n_comp;++i)
00403 {
00404 if (m_wts[i]<=0) continue;
00405 if (i==0 || probs[i](j)>max_log_p) max_log_p = probs[i](j);
00406 }
00407
00408
00409 double sum = 0.0;
00410 for (unsigned int i=0;i<n_comp;++i)
00411 {
00412 if (m_wts[i]<=0) continue;
00413 double p = vcl_exp(probs[i](j)-max_log_p);
00414 probs[i](j) = p;
00415 sum+=p;
00416 }
00417
00418
00419 if (sum>0.0)
00420 for (unsigned int i=0;i<n_comp;++i)
00421 probs[i](j)/=sum;
00422
00423 if (sum<=0)
00424 vcl_cerr<<"vpdfl_mixture_builder::e_step() Zero sum for probs!\n";
00425 }
00426 }
00427
00428
00429
00430 double vpdfl_mixture_builder::m_step(vpdfl_mixture& model,
00431 const vcl_vector<vnl_vector<double> >& probs,
00432 const vnl_vector<double>* data,
00433 const vcl_vector<double>& wts) const
00434 {
00435 const unsigned int n_comp = builder_.size();
00436 const unsigned int n_egs = wts.size();
00437 vcl_vector<double> wts_i(n_egs);
00438
00439 mbl_data_array_wrapper<vnl_vector<double> > data_array(data,n_egs);
00440
00441 double move = 0.0;
00442 vnl_vector<double> old_mean;
00443
00444 if (!weights_fixed_)
00445 {
00446 double w_sum = 0.0;
00447
00448 for (unsigned int i=0;i<n_comp;++i)
00449 {
00450 model.weights()[i]=probs[i].mean();
00451
00452
00453 if (model.weights()[i]<min_wt) model.weights()[i]=0.0;
00454
00455 w_sum += model.weights()[i];
00456 }
00457
00458
00459 for (unsigned int i=0;i<n_comp;++i)
00460 model.weights()[i]/=w_sum;
00461 }
00462
00463 for (unsigned int i=0;i<n_comp;++i)
00464 {
00465
00466
00467 if (model.weights()[i]<=0.0) continue;
00468
00469
00470 const double* p = probs[i].begin();
00471 double w_sum = 0.0;
00472 for (unsigned int j=0;j<n_egs;++j)
00473 {
00474 wts_i[j] = wts[j]*p[j];
00475 w_sum += wts_i[j];
00476 }
00477
00478 if (w_sum<=0.0)
00479 vcl_cerr<<"m_step: Dubious weights. sum="<<w_sum<<'\n';
00480
00481 old_mean = model.components()[i]->mean();
00482 builder_[i]->weighted_build(*(model.components()[i]), data_array, wts_i);
00483
00484 move += vnl_vector_ssd(old_mean, model.components()[i]->mean());
00485 }
00486
00487
00488 return move;
00489 }
00490
00491
00492
00493
00494 static inline void incXbyYv(vnl_vector<double> *X, const vnl_vector<double> &Y, double v)
00495 {
00496 assert(X->size() == Y.size());
00497 int i = ((int)X->size()) - 1;
00498 double * const pX=X->data_block();
00499 while (i >= 0)
00500 {
00501 pX[i] += Y[i] * v;
00502 i--;
00503 }
00504 }
00505
00506
00507 static inline void incXbyYplusXXv(vnl_vector<double> *X, const vnl_vector<double> &Y,
00508 const vnl_vector<double> &Z, double v)
00509 {
00510 assert(X->size() == Y.size());
00511 int i = ((int)X->size()) - 1;
00512 double * const pX=X->data_block();
00513 while (i >= 0)
00514 {
00515 pX[i] += (Y[i] + vnl_math_sqr(Z[i]))* v;
00516 i--;
00517 }
00518 }
00519
00520
00521
00522 void vpdfl_mixture_builder::calc_mean_and_variance(vpdfl_mixture& model)
00523 {
00524 unsigned int n = model.component(0).mean().size();
00525 vnl_vector<double> mean(n, 0.0);
00526 vnl_vector<double> var(n, 0.0);
00527
00528 for (unsigned int i=0; i<model.n_components(); ++i)
00529 {
00530 incXbyYv(&mean, model.component(i).mean(), model.weight(i));
00531 incXbyYplusXXv(&var, model.component(i).variance(),
00532 model.component(i).mean(), model.weight(i));
00533 }
00534
00535 for (unsigned int i=0; i<n; ++i)
00536 var(i) -= vnl_math_sqr(mean(i));
00537
00538 model.set_mean_and_variance(mean, var);
00539 }
00540
00541
00542
00543 vcl_string vpdfl_mixture_builder::is_a() const
00544 {
00545 return vcl_string("vpdfl_mixture_builder");
00546 }
00547
00548
00549
00550 bool vpdfl_mixture_builder::is_class(vcl_string const& s) const
00551 {
00552 return vpdfl_builder_base::is_class(s) || s==vpdfl_mixture_builder::is_a();
00553 }
00554
00555
00556
00557 short vpdfl_mixture_builder::version_no() const
00558 {
00559 return 1;
00560 }
00561
00562
00563
00564 vpdfl_builder_base* vpdfl_mixture_builder::clone() const
00565 {
00566 return new vpdfl_mixture_builder(*this);
00567 }
00568
00569
00570
00571 void vpdfl_mixture_builder::print_summary(vcl_ostream& os) const
00572 {
00573 if (weights_fixed_) os<<vsl_indent()<<"Weights fixed"<<'\n';
00574 else os<<vsl_indent()<<"Weights may vary"<<'\n';
00575 os<<vsl_indent()<<"Max iterations: "<<max_its_<<'\n';
00576 for (unsigned int i=0;i<builder_.size();++i)
00577 {
00578 os<<vsl_indent()<<"Builder "<<i<<": ";
00579 vsl_print_summary(os, builder_[i]); os << '\n';
00580 }
00581 }
00582
00583
00584
00585 void vpdfl_mixture_builder::b_write(vsl_b_ostream& bfs) const
00586 {
00587 vsl_b_write(bfs,is_a());
00588 vsl_b_write(bfs,version_no());
00589 vsl_b_write(bfs,builder_);
00590 vsl_b_write(bfs,max_its_);
00591 vsl_b_write(bfs,weights_fixed_);
00592 }
00593
00594
00595
00596 void vpdfl_mixture_builder::b_read(vsl_b_istream& bfs)
00597 {
00598 if (!bfs) return;
00599
00600 vcl_string name;
00601 vsl_b_read(bfs,name);
00602 if (name != is_a())
00603 {
00604 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, vpdfl_mixture_builder &)\n"
00605 << " Attempted to load object of type "
00606 << name <<" into object of type " << is_a() << '\n';
00607 bfs.is().clear(vcl_ios::badbit);
00608 return;
00609 }
00610
00611 delete_stuff();
00612
00613 short version;
00614 vsl_b_read(bfs,version);
00615 switch (version)
00616 {
00617 case (1):
00618 vsl_b_read(bfs,builder_);
00619 vsl_b_read(bfs,max_its_);
00620 vsl_b_read(bfs,weights_fixed_);
00621 break;
00622 default:
00623 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, vpdfl_mixture_builder &)\n"
00624 << " Unknown version number "<< version << '\n';
00625 bfs.is().clear(vcl_ios::badbit);
00626 return;
00627 }
00628 }
00629
00630
00631
00632
00633
00634
00635
00636
00637
00638
00639
00640 void vpdfl_mixture_builder::config_from_stream(vcl_istream & is)
00641 {
00642 vcl_string s = mbl_parse_block(is);
00643
00644 vcl_istringstream ss(s);
00645 mbl_read_props_type props = mbl_read_props_ws(ss);
00646
00647 double mv=1.0e-6;
00648 if (props.find("min_var")!=props.end())
00649 {
00650 mv=vul_string_atof(props["min_var"]);
00651 props.erase("min_var");
00652 }
00653 set_min_var(mv);
00654
00655 unsigned n_pdfs = 2;
00656 if (props.find("n_pdfs")!=props.end())
00657 {
00658 n_pdfs=vul_string_atoi(props["n_pdfs"]);
00659 props.erase("n_pdfs");
00660 }
00661
00662 max_its_=10;
00663 if (props.find("max_its")!=props.end())
00664 {
00665 max_its_=vul_string_atoi(props["max_its"]);
00666 props.erase("max_its");
00667 }
00668
00669 weights_fixed_=false;
00670 if (props.find("weights_fixed")!=props.end())
00671 {
00672 weights_fixed_=vul_string_to_bool(props["weights_fixed"]);
00673 props.erase("weights_fixed");
00674 }
00675
00676 if (props.find("basis_pdf")!=props.end())
00677 {
00678 vcl_istringstream pdf_ss(props["basis_pdf"]);
00679 vcl_auto_ptr<vpdfl_builder_base>
00680 b = vpdfl_builder_base::new_pdf_builder_from_stream(pdf_ss);
00681 init(*b,n_pdfs);
00682 props.erase("basis_pdf");
00683 }
00684
00685 try
00686 {
00687 mbl_read_props_look_for_unused_props(
00688 "vpdfl_mixture_builder::config_from_stream", props);
00689 }
00690
00691 catch(mbl_exception_unused_props &e)
00692 {
00693 throw mbl_exception_parse_error(e.what());
00694 }
00695 }
00696
00697