00001
00002 #include "pdf1d_mixture_builder.h"
00003
00004
00005
00006
00007
00008 #include <vcl_cassert.h>
00009 #include <vcl_cmath.h>
00010 #include <vcl_algorithm.h>
00011 #include <vcl_cstdlib.h>
00012 #include <vsl/vsl_indent.h>
00013 #include <vsl/vsl_vector_io.h>
00014 #include <vsl/vsl_binary_loader.h>
00015 #include <pdf1d/pdf1d_mixture.h>
00016 #include <mbl/mbl_data_wrapper.h>
00017 #include <mbl/mbl_data_array_wrapper.h>
00018
00019
// Smallest mixture weight that m_step() keeps: a re-estimated weight below
// this value is set to zero, and e_step()/m_step() skip any component whose
// weight is not positive.
const double min_wt = 1e-8;
00021
00022
00023 void pdf1d_mixture_builder::init()
00024 {
00025 min_var_ = 1.0e-6;
00026 max_its_ = 10;
00027 weights_fixed_ = false;
00028 }
00029
00030
00031
00032
00033 pdf1d_mixture_builder::pdf1d_mixture_builder()
00034 {
00035 init();
00036 }
00037
00038
00039 pdf1d_mixture_builder::pdf1d_mixture_builder(const pdf1d_mixture_builder& b):
00040 pdf1d_builder()
00041 {
00042 init();
00043 *this = b;
00044 }
00045
00046
00047 pdf1d_mixture_builder& pdf1d_mixture_builder::operator=(const pdf1d_mixture_builder& b)
00048 {
00049 if (&b==this) return *this;
00050
00051 delete_stuff();
00052
00053 int n = b.builder_.size();
00054 builder_.resize(n);
00055 for (int i=0;i<n;++i)
00056 builder_[i] = b.builder_[i]->clone();
00057
00058 min_var_ = b.min_var_;
00059 max_its_ = b.max_its_;
00060 weights_fixed_ = b.weights_fixed_;
00061
00062 return *this;
00063 }
00064
00065
00066
00067 void pdf1d_mixture_builder::delete_stuff()
00068 {
00069 int n = builder_.size();
00070 for (int i=0;i<n;++i)
00071 delete builder_[i];
00072 builder_.resize(0);
00073 }
00074
00075 pdf1d_mixture_builder::~pdf1d_mixture_builder()
00076 {
00077 delete_stuff();
00078 }
00079
00080
00081
00082
00083
00084 void pdf1d_mixture_builder::init(pdf1d_builder& builder, int n)
00085 {
00086 delete_stuff();
00087 builder_.resize(n);
00088 for (int i=0;i<n;++i)
00089 builder_[i] = builder.clone();
00090 }
00091
00092
00093
00094
00095 void pdf1d_mixture_builder::set_max_iterations(int n)
00096 {
00097 max_its_ = n;
00098 }
00099
00100
00101 void pdf1d_mixture_builder::set_weights_fixed(bool b)
00102 {
00103 weights_fixed_ = b;
00104 }
00105
00106
00107
00108
00109 pdf1d_pdf* pdf1d_mixture_builder::new_model() const
00110 {
00111 return new pdf1d_mixture;
00112 }
00113
00114 vcl_string pdf1d_mixture_builder::new_model_type() const
00115 {
00116 return vcl_string("pdf1d_mixture");
00117 }
00118
00119
00120
00121
00122 void pdf1d_mixture_builder::set_min_var(double min_var)
00123 {
00124 min_var_ = min_var;
00125 }
00126
00127
00128
00129
00130 double pdf1d_mixture_builder::min_var() const
00131 {
00132 return min_var_;
00133 }
00134
00135
00136
00137
00138 void pdf1d_mixture_builder::build(pdf1d_pdf& , double ) const
00139 {
00140 vcl_cerr<<"pdf1d_mixture_builder::build(model,mean) not yet implemented.\n";
00141 vcl_abort();
00142 }
00143
00144
00145
00146
00147 void pdf1d_mixture_builder::build(pdf1d_pdf& model,
00148 mbl_data_wrapper<double>& data) const
00149 {
00150 vcl_vector<double> wts(data.size());
00151 vcl_fill(wts.begin(),wts.end(),1.0);
00152 weighted_build(model,data,wts);
00153 }
00154
00155
00156
00157
00158 void pdf1d_mixture_builder::weighted_build(pdf1d_pdf& base_model,
00159 mbl_data_wrapper<double>& data,
00160 const vcl_vector<double>& wts) const
00161 {
00162 assert(base_model.is_class("pdf1d_mixture"));
00163 pdf1d_mixture& model = static_cast<pdf1d_mixture&>(base_model);
00164
00165 unsigned int n = builder_.size();
00166
00167 bool model_setup = (model.n_components()==n);
00168
00169 if (!model_setup)
00170 {
00171
00172 model.clear();
00173 model.components().resize(n);
00174 model.weights().resize(n);
00175 for (unsigned int i=0;i<n;++i)
00176 {
00177 builder_[i]->set_min_var(min_var_);
00178 model.components()[i] = builder_[i]->new_model();
00179 model.weights()[i] = 1.0/n;
00180 }
00181 }
00182
00183
00184 const double* data_ptr;
00185 vcl_vector<double> data_array;
00186
00187 {
00188 int n=data.size();
00189 data.reset();
00190 data_array.resize(n);
00191 for (int i=0;i<n;++i)
00192 {
00193 data_array[i] = data.current();
00194 data.next();
00195 }
00196
00197 data_ptr = &data_array[0];
00198 }
00199
00200 if (!model_setup)
00201 initialise(model,data_ptr,wts);
00202
00203 vcl_vector<vnl_vector<double> > probs;
00204
00205 int n_its = 0;
00206 double max_move = 1e-6;
00207 double move = max_move+1;
00208 while (move>max_move && n_its<max_its_)
00209 {
00210 e_step(model,probs,data_ptr,wts);
00211 move = m_step(model,probs,data_ptr,wts);
00212 n_its++;
00213 }
00214 calc_mean_and_variance(model);
00215 assert(model.is_valid_pdf());
00216 }
00217
// Widen the interval [min_v,max_v] so that it contains v.
// At most one bound is changed per call: if v lowers the minimum, the
// maximum is left alone.
static void UpdateRange(double& min_v, double& max_v, double v)
{
  if (v<min_v)
  {
    min_v = v;
    return;
  }

  if (v>max_v)
    max_v = v;
}
00224
00225
00226 void pdf1d_mixture_builder::initialise(pdf1d_mixture& model,
00227 const double* data,
00228 const vcl_vector<double>& wts) const
00229 {
00230
00231 int n_comp = builder_.size();
00232 int n_samples = wts.size();
00233
00234 vcl_vector<double> wts_i(n_samples);
00235
00236
00237 double min_v = data[0];
00238 double max_v = min_v;
00239 for (int i=1;i<n_samples;++i)
00240 UpdateRange(min_v,max_v,data[i]);
00241
00242
00243 vcl_vector<double> mean(n_comp);
00244 for (int i=0;i<n_comp;++i)
00245 {
00246 double f = (i+1.0)/(n_comp+1);
00247 mean[i] = (1-f)*min_v + f*max_v;
00248 }
00249
00250 double mean_sep = (max_v-min_v)/n_samples;
00251
00252 mbl_data_array_wrapper<double> data_array(data,n_samples);
00253
00254 for (int i=0;i<n_comp;++i)
00255 {
00256
00257 double w_sum = 0.0;
00258 for (int j=0;j<n_samples;++j)
00259 {
00260 wts_i[j] = mean_sep/(mean_sep + vcl_fabs(data[j]-mean[i]));
00261 w_sum+=wts_i[j];
00262 }
00263
00264
00265 double f = double(n_samples)/(n_comp*w_sum);
00266 for (int j=0;j<n_samples;++j) wts_i[j]*=f;
00267
00268
00269 builder_[i]->weighted_build(*(model.components()[i]),data_array,wts_i);
00270 }
00271 }
00272
00273
00274 void pdf1d_mixture_builder::e_step(pdf1d_mixture& model,
00275 vcl_vector<vnl_vector<double> >& probs,
00276 const double* data,
00277 const vcl_vector<double>& wts) const
00278 {
00279 unsigned int n_comp = builder_.size();
00280 unsigned int n_egs = wts.size();
00281 const vcl_vector<double>& m_wts = model.weights();
00282
00283 if (probs.size()!=n_comp) probs.resize(n_comp);
00284
00285
00286
00287 for (unsigned int i=0;i<n_comp;++i)
00288 {
00289 if (probs[i].size()!=n_egs) probs[i].set_size(n_egs);
00290
00291
00292
00293 if (m_wts[i]<=0) continue;
00294
00295 double *p_data = probs[i].begin();
00296 double log_wt_i = vcl_log(model.weights()[i]);
00297
00298 for (unsigned int j=0;j<n_egs;++j)
00299 p_data[j] = log_wt_i + model.components()[i]->log_p(data[j]);
00300 }
00301
00302
00303
00304 for (unsigned int j=0;j<n_egs;++j)
00305 {
00306
00307 double max_log_p = 0;
00308 for (unsigned int i=0;i<n_comp;++i)
00309 {
00310 if (m_wts[i]<=0) continue;
00311 if (i==0 || probs[i](j)>max_log_p) max_log_p = probs[i](j);
00312 }
00313
00314
00315 double sum = 0.0;
00316 for (unsigned int i=0;i<n_comp;++i)
00317 {
00318 if (m_wts[i]<=0) continue;
00319 double p = vcl_exp(probs[i](j)-max_log_p);
00320 probs[i](j) = p;
00321 sum+=p;
00322 }
00323
00324
00325 if (sum>0.0)
00326 for (unsigned int i=0;i<n_comp;++i)
00327 probs[i](j)/=sum;
00328 }
00329 }
00330
00331
00332 double pdf1d_mixture_builder::m_step(pdf1d_mixture& model,
00333 const vcl_vector<vnl_vector<double> >& probs,
00334 const double* data,
00335 const vcl_vector<double>& wts) const
00336 {
00337 int n_comp = builder_.size();
00338 int n_egs = wts.size();
00339 vcl_vector<double> wts_i(n_egs);
00340
00341 mbl_data_array_wrapper<double> data_array(data,n_egs);
00342
00343 double move = 0.0;
00344 double old_mean;
00345
00346 if (!weights_fixed_)
00347 {
00348 double w_sum = 0.0;
00349
00350 for (int i=0;i<n_comp;++i)
00351 {
00352 model.weights()[i]=probs[i].mean();
00353
00354
00355 if (model.weights()[i]<min_wt) model.weights()[i]=0.0;
00356
00357 w_sum += model.weights()[i];
00358 }
00359
00360
00361 for (int i=0;i<n_comp;++i)
00362 model.weights()[i]/=w_sum;
00363 }
00364
00365 for (int i=0;i<n_comp;++i)
00366 {
00367
00368
00369 if (model.weights()[i]<=0) continue;
00370
00371
00372 const double* p = probs[i].begin();
00373 for (int j=0;j<n_egs;++j)
00374 wts_i[j] = wts[j]*p[j];
00375
00376 old_mean = model.components()[i]->mean();
00377 builder_[i]->weighted_build(*(model.components()[i]), data_array, wts_i);
00378
00379 move += vcl_fabs(old_mean-model.components()[i]->mean());
00380 }
00381
00382
00383 return move;
00384 }
00385
00386
00387
00388
00389 void pdf1d_mixture_builder::calc_mean_and_variance(pdf1d_mixture& model)
00390 {
00391 double sum = 0;
00392 double sum2 = 0;
00393
00394 unsigned i;
00395 for (i=0; i<model.n_components(); ++i)
00396 {
00397 double wi = model.weight(i);
00398 double mean_i = model.component(i).mean();
00399 sum += mean_i * wi;
00400 sum2 += (model.component(i).variance()+mean_i*mean_i)*wi;
00401 }
00402
00403 double mean = sum;
00404 double var = sum2-mean*mean;
00405
00406 model.set_mean_and_variance(mean, var);
00407 }
00408
00409
00410
00411 vcl_string pdf1d_mixture_builder::is_a() const
00412 {
00413 return vcl_string("pdf1d_mixture_builder");
00414 }
00415
00416
00417
00418 bool pdf1d_mixture_builder::is_class(vcl_string const& s) const
00419 {
00420 return pdf1d_builder::is_class(s) || s==pdf1d_mixture_builder::is_a();
00421 }
00422
00423
00424
00425 short pdf1d_mixture_builder::version_no() const
00426 {
00427 return 1;
00428 }
00429
00430
00431
00432 pdf1d_builder* pdf1d_mixture_builder::clone() const
00433 {
00434 return new pdf1d_mixture_builder(*this);
00435 }
00436
00437
00438
00439 void pdf1d_mixture_builder::print_summary(vcl_ostream& os) const
00440 {
00441 for (unsigned int i=0;i<builder_.size();++i)
00442 {
00443 os<<'\n'<<vsl_indent()<<"Builder "<<i<<": ";
00444 vsl_print_summary(os, builder_[i]);
00445 }
00446 os<<'\n';
00447 }
00448
00449
00450
00451 void pdf1d_mixture_builder::b_write(vsl_b_ostream& bfs) const
00452 {
00453 vsl_b_write(bfs,is_a());
00454 vsl_b_write(bfs,version_no());
00455 vsl_b_write(bfs,builder_);
00456 vsl_b_write(bfs,max_its_);
00457 vsl_b_write(bfs,weights_fixed_);
00458 }
00459
00460
00461
00462 void pdf1d_mixture_builder::b_read(vsl_b_istream& bfs)
00463 {
00464 if (!bfs) return;
00465
00466 vcl_string name;
00467 vsl_b_read(bfs,name);
00468 if (name != is_a())
00469 {
00470 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, pdf1d_mixture_builder &)\n"
00471 << " Attempted to load object of type "
00472 << name <<" into object of type " << is_a() << vcl_endl;
00473 bfs.is().clear(vcl_ios::badbit);
00474 return;
00475 }
00476
00477 delete_stuff();
00478
00479 short version;
00480 vsl_b_read(bfs,version);
00481 switch (version)
00482 {
00483 case (1):
00484 vsl_b_read(bfs,builder_);
00485 vsl_b_read(bfs,max_its_);
00486 vsl_b_read(bfs,weights_fixed_);
00487 break;
00488 default:
00489 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, pdf1d_mixture_builder &)\n"
00490 << " Unknown version number "<< version << vcl_endl;
00491 bfs.is().clear(vcl_ios::badbit);
00492 return;
00493 }
00494 }
00495
00496