contrib/mul/clsfy/clsfy_mean_square_1d_builder.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_mean_square_1d_builder.cxx
00002 #include "clsfy_mean_square_1d_builder.h"
00003 //:
00004 // \file
00005 // \author dac
00006 // \date   Tue Mar  5 01:11:31 2002
00007 
00008 #include <vcl_cmath.h>
00009 #include <vcl_iostream.h>
00010 #include <vcl_string.h>
00011 #include <vcl_cassert.h>
00012 #include <vcl_cstdlib.h>
00013 #include <vsl/vsl_binary_loader.h>
00014 #include <vnl/vnl_double_2.h>
00015 #include <clsfy/clsfy_builder_1d.h>
00016 #include <clsfy/clsfy_mean_square_1d.h>
00017 #include <vcl_algorithm.h>
00018 
00019 //=======================================================================
00020 
00021 clsfy_mean_square_1d_builder::clsfy_mean_square_1d_builder()
00022 {
00023 }
00024 
00025 //=======================================================================
00026 
00027 clsfy_mean_square_1d_builder::~clsfy_mean_square_1d_builder()
00028 {
00029 }
00030 
00031 //=======================================================================
00032 
00033 short clsfy_mean_square_1d_builder::version_no() const
00034 {
00035   return 1;
00036 }
00037 
00038 
00039 //: Create empty classifier
00040 // Caller is responsible for deletion
00041 clsfy_classifier_1d* clsfy_mean_square_1d_builder::new_classifier() const
00042 {
00043   return new clsfy_mean_square_1d();
00044 }
00045 
00046 
00047 //: Build a binary_threshold classifier
00048 //  Train classifier, returning weighted error
00049 //  Selects parameters of classifier which best separate examples from two classes,
00050 //  weighting examples appropriately when estimating the misclassification rate.
00051 //  Returns weighted sum of error, e.wts, where e_i =0 for correct classifications,
00052 //  e_i=1 for incorrect.
00053 double clsfy_mean_square_1d_builder::build(clsfy_classifier_1d& classifier,
00054                                            const vnl_vector<double>& egs,
00055                                            const vnl_vector<double>& wts,
00056                                            const vcl_vector<unsigned> &outputs) const
00057 {
00058   // this method sorts the data and passes it to the method below
00059   assert(classifier.is_class("clsfy_mean_square_1d"));
00060 
00061   unsigned int n = egs.size();
00062   assert ( wts.size() == n );
00063   assert ( outputs.size() == n );
00064 
00065   // calc weighted mean of positive data
00066   double wm_pos= 0.0;
00067   double tot_pos_wts=0.0, tot_neg_wts=0.0;
00068   unsigned int n_pos=0, n_neg=0;
00069   for (unsigned int i=0; i<n; ++i)
00070   {
00071 #ifdef DEBUG
00072     vcl_cout<<"egs["<<i<<"]= "<<egs[i]<<vcl_endl
00073             <<"wts["<<i<<"]= "<<wts[i]<<vcl_endl
00074             <<"outputs["<<i<<"]= "<<outputs[i]<<vcl_endl;
00075 #endif
00076     if ( outputs[i] == 1 )
00077     {
00078       //vcl_cout<<"wm_pos= "<<wm_pos<<vcl_endl;
00079       wm_pos+= wts(i)*egs(i);
00080       tot_pos_wts+= wts(i);
00081       ++n_pos;
00082     }
00083     else
00084     {
00085       tot_neg_wts+= wts(i);
00086       ++n_neg;
00087     }
00088   }
00089 
00090   assert( n_pos+n_neg== n );
00091   wm_pos/=tot_pos_wts;
00092 #ifdef DEBUG
00093   vcl_cout<<"wm_pos= "<<wm_pos<<vcl_endl;
00094 #endif
00095   // create triples data, so can sort
00096   vcl_vector<vbl_triple<double,int,int> > data;
00097 
00098   vbl_triple<double,int,int> t;
00099   // add data to triples
00100   for (unsigned int i=0;i<n;++i)
00101   {
00102     double k= wm_pos-egs[i];
00103     t.first=k*k;
00104     t.second= outputs[i];
00105     t.third = i;
00106     data.push_back(t);
00107   }
00108 
00109   vbl_triple<double,int,int> *data_ptr=&data[0];
00110   vcl_sort(data_ptr,data_ptr+n);
00111 
00112   double wt_pos=0;
00113   double wt_neg=0;
00114   double min_error= 1000000;
00115   double min_thresh= -1;
00116   for (unsigned int i=0;i<n;++i)
00117   {
00118     if ( data[i].second == 0 ) wt_neg+= wts[ data[i].third] ;
00119     else if ( data[i].second == 1 ) wt_pos+= wts[ data[i].third];
00120     else
00121     {
00122       vcl_cout<<"ERROR: clsfy_mean_square_1d_builder::build()\n"
00123               <<"Unrecognised output value in triple (ie must be 0 or 1)\n"
00124               <<"data.second="<<data[i].second<<vcl_endl;
00125       vcl_abort();
00126     }
00127     double error= tot_pos_wts-wt_pos+wt_neg;
00128 #ifdef DEBUG
00129     vcl_cout<<"data[i].first= "<<data[i].first<<vcl_endl
00130             <<"data[i].second= "<<data[i].second<<vcl_endl
00131             <<"data[i].third= "<<data[i].third<<vcl_endl
00132 
00133             <<"wt_pos= "<<wt_pos<<vcl_endl
00134             <<"tot_wts1= "<<tot_wts1<<vcl_endl
00135             <<"wt_neg= "<<wt_neg<<vcl_endl
00136 
00137             <<"error= "<<error<<vcl_endl;
00138 #endif
00139     if ( error< min_error )
00140     {
00141       min_error= error;
00142       min_thresh = data[i].first + 0.001 ;
00143     }
00144   }
00145 
00146   assert( vcl_fabs (wt_pos - tot_pos_wts) < 1e-9 );
00147   assert( vcl_fabs (wt_neg - tot_neg_wts) < 1e-9 );
00148 #ifdef DEBUG
00149   vcl_cout<<"min_error= "<<min_error<<vcl_endl
00150           <<"min_thresh= "<<min_thresh<<vcl_endl;
00151 #endif
00152   // pass parameters to classifier
00153   classifier.set_params(vnl_double_2(wm_pos,min_thresh).as_vector());
00154   return min_error;
00155 }
00156 
00157 
00158 //: Build a mean_square classifier
00159 // nb here egs0 are -ve examples
00160 // and egs1 are +ve examples
00161 double clsfy_mean_square_1d_builder::build(clsfy_classifier_1d& classifier,
00162                                            vnl_vector<double>& egs0,
00163                                            vnl_vector<double>& wts0,
00164                                            vnl_vector<double>& egs1,
00165                                            vnl_vector<double>& wts1)  const
00166 {
00167   // this method sorts the data and passes it to the method below
00168   assert(classifier.is_class("clsfy_mean_square_1d"));
00169 
00170   // find mean of positive data (ie egs1) then calc square distance from mean
00171   // for each example
00172   unsigned int n0 = egs0.size();
00173   unsigned int n1 = egs1.size();
00174   assert (wts0.size() == n0 );
00175   assert (wts1.size() == n1 );
00176 
00177   // calc weighted mean of positive data
00178   double tot_wts1= wts1.mean()*n1;
00179   double wm_pos=0.0;
00180   for (unsigned int i=0; i< n1; ++i)
00181   {
00182     wm_pos+= wts1(i)*egs1(i);
00183 #ifdef DEBUG
00184     vcl_cout<<"egs1("<<i<<")= "<<egs1(i)<<vcl_endl
00185             <<"wts1("<<i<<")= "<<wts1(i)<<vcl_endl;
00186 #endif
00187   }
00188   wm_pos/=tot_wts1;
00189 
00190   vcl_cout<<"wm_pos= "<<wm_pos<<vcl_endl;
00191 
00192   vcl_vector<vbl_triple<double,int,int> > data;
00193 
00194   vnl_vector<double> wts(n0+n1);
00195   vbl_triple<double,int,int> t;
00196   // add data for class 0
00197   for (unsigned int i=0;i<n0;++i)
00198   {
00199     double k= wm_pos-egs0[i];
00200     t.first=k*k;
00201     t.second=0;
00202     t.third = i;
00203     wts(i)= wts0[i];
00204     data.push_back(t);
00205   }
00206 
00207   // add data for class 1
00208   for (unsigned int i=0;i<n1;++i)
00209   {
00210     double k= wm_pos-egs1[i];
00211     t.first=k*k;
00212     t.second=1;
00213     t.third = i+n0;
00214     wts(i+n0)= wts1[i];
00215     data.push_back(t);
00216   }
00217 
00218   unsigned int n=n0+n1;
00219 
00220   vbl_triple<double,int,int> *data_ptr=&data[0];
00221   vcl_sort(data_ptr,data_ptr+n);
00222 
00223   double wt_pos=0;
00224   double wt_neg=0;
00225   double min_error= 1000000;
00226   double min_thresh= -1;
00227   for (unsigned int i=0;i<n;++i)
00228   {
00229     if ( data[i].second == 0 ) wt_neg+= wts[ data[i].third] ;
00230     else if ( data[i].second == 1 ) wt_pos+= wts[ data[i].third];
00231     else
00232     {
00233       vcl_cout<<"ERROR: clsfy_mean_square_1d_builder::build()\n"
00234               <<"Unrecognised output value in triple (ie must be 0 or 1)\n"
00235               <<"data.second="<<data[i].second<<vcl_endl;
00236       vcl_abort();
00237     }
00238     double error= tot_wts1-wt_pos+wt_neg;
00239 #ifdef DEBUG
00240     vcl_cout<<"data[i].first= "<<data[i].first<<vcl_endl
00241             <<"data[i].second= "<<data[i].second<<vcl_endl
00242             <<"data[i].third= "<<data[i].third<<vcl_endl
00243 
00244             <<"wt_pos= "<<wt_pos<<vcl_endl
00245             <<"tot_wts1= "<<tot_wts1<<vcl_endl
00246             <<"wt_neg= "<<wt_neg<<vcl_endl
00247 
00248             <<"error= "<<error<<vcl_endl;
00249 #endif
00250     if ( error< min_error )
00251     {
00252       min_error= error;
00253       min_thresh = data[i].first + 0.001 ;
00254     }
00255   }
00256 
00257   assert( vcl_fabs (wt_pos - tot_wts1) < 1e-9 );
00258   assert( vcl_fabs (wt_neg - wts0.mean()*n0) < 1e-9 );
00259   vcl_cout<<"min_error= "<<min_error<<vcl_endl
00260           <<"min_thresh= "<<min_thresh<<vcl_endl;
00261 
00262   // pass parameters to classifier
00263   classifier.set_params(vnl_double_2(wm_pos,min_thresh).as_vector());
00264   return min_error;
00265 }
00266 
00267 
00268 //: Train classifier, returning weighted error
00269 //   Assumes two classes
00270 double clsfy_mean_square_1d_builder::build_from_sorted_data(
00271                                   clsfy_classifier_1d& /*classifier*/,
00272                                   const vbl_triple<double,int,int>* /*data*/,
00273                                   const vnl_vector<double>& /*wts*/
00274                                   ) const
00275 {
00276   vcl_cout<<"ERROR: clsfy_mean_square_1d_builder::build_from_sorted_data()\n"
00277           <<"Function not implemented because can't use pre-sorted data\n"
00278           <<"the weighted mean of the data is needed to calc the ordering!\n";
00279   vcl_abort();
00280 
00281   return 0.0;
00282 }
00283 
00284 //=======================================================================
00285 
00286 vcl_string clsfy_mean_square_1d_builder::is_a() const
00287 {
00288   return vcl_string("clsfy_mean_square_1d_builder");
00289 }
00290 
00291 bool clsfy_mean_square_1d_builder::is_class(vcl_string const& s) const
00292 {
00293   return s == clsfy_mean_square_1d_builder::is_a() || clsfy_builder_1d::is_class(s);
00294 }
00295 
00296 //=======================================================================
00297 
#if 0 // copy constructor and assignment operator, compiled out
// NOTE(review): data_ptr_ and data_ are not referenced anywhere else in this
// file; this looks like template boilerplate kept for reference - confirm
// against the class declaration before enabling.

//: Copy constructor
// Only required if this derived class stores data on the heap.
clsfy_mean_square_1d_builder::clsfy_mean_square_1d_builder(
    const clsfy_mean_square_1d_builder& new_b)
  : data_ptr_(0)
{
  *this = new_b;
}

//=======================================================================

//: Assignment operator
// Only required if this derived class stores data on the heap.
clsfy_mean_square_1d_builder&
clsfy_mean_square_1d_builder::operator=(const clsfy_mean_square_1d_builder& new_b)
{
  if (this != &new_b)
  {
    // Release the current heap member, then deep-copy the other one if set.
    delete data_ptr_;
    data_ptr_ = 0;
    if (new_b.data_ptr_)
      data_ptr_ = new_b.data_ptr_->clone();

    // Plain members are copied by value.
    data_ = new_b.data_;
  }
  return *this;
}

#endif // 0
00329 
00330 //=======================================================================
00331 
00332 clsfy_builder_1d* clsfy_mean_square_1d_builder::clone() const
00333 {
00334   return new clsfy_mean_square_1d_builder(*this);
00335 }
00336 
00337 //=======================================================================
00338 
00339 // required if data is present in this base class
00340 void clsfy_mean_square_1d_builder::print_summary(vcl_ostream& /*os*/) const
00341 {
00342   // clsfy_builder_1d::print_summary(os); // Uncomment this line if it has one.
00343   // vsl_print_summary(os, data_); // Example of data output
00344 
00345   vcl_cerr << "clsfy_mean_square_1d_builder::print_summary() NYI\n";
00346 }
00347 
00348 //=======================================================================
00349 
00350 // required if data is present in this base class
00351 void clsfy_mean_square_1d_builder::b_write(vsl_b_ostream& /*bfs*/) const
00352 {
00353   //vsl_b_write(bfs, version_no());
00354   //clsfy_builder_1d::b_write(bfs);  // Needed if base has any data
00355   //vsl_b_write(bfs, data_);
00356   vcl_cerr << "clsfy_mean_square_1d_builder::b_write() NYI\n";
00357 }
00358 
00359 //=======================================================================
00360 
00361 // required if data is present in this base class
00362 void clsfy_mean_square_1d_builder::b_read(vsl_b_istream& /*bfs*/)
00363 {
00364   vcl_cerr << "clsfy_mean_square_1d_builder::b_read() NYI\n";
00365 #if 0
00366   if (!bfs) return;
00367 
00368   short version;
00369   vsl_b_read(bfs,version);
00370   switch (version)
00371   {
00372   case (1):
00373     //clsfy_builder_1d::b_read(bfs);  // Needed if base has any data
00374     vsl_b_read(bfs,data_);
00375     break;
00376   default:
00377     vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_mean_square_1d_builder&)\n"
00378              << "           Unknown version number "<< version << '\n';
00379     bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00380     return;
00381   }
00382 #endif
00383 }