contrib/mul/clsfy/clsfy_adaboost_trainer.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_adaboost_trainer.cxx
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005 //:
00006 // \file
00007 // \brief Functions to train classifiers using AdaBoost algorithm
00008 // \author dac
00009 // \date   Fri Mar  1 23:49:39 2002
00010 //  Functions to train classifiers using AdaBoost algorithm
00011 //  AdaBoost combines a set of (usually simple, weak) classifiers into
00012 //  a more powerful single classifier.  Essentially it selects the
00013 //  classifiers one at a time, choosing the best at each step.
00014 //  The classifiers are trained to distinguish the examples mis-classified
00015 //  by the currently selected classifiers.
00016 // \verbatim
00017 // Modifications
00018 // \endverbatim
00019 
00020 #include "clsfy_adaboost_trainer.h"
00021 
00022 #include <vcl_iostream.h>
00023 #include <vsl/vsl_indent.h>
00024 #include <vcl_cmath.h>
00025 #include <vcl_cassert.h>
00026 
00027 //=======================================================================
00028 
//: Default constructor.
//  The trainer is stateless, so there is nothing to initialise.
clsfy_adaboost_trainer::clsfy_adaboost_trainer()
{
}
00032 
00033 //=======================================================================
00034 
//: Destructor.
//  No heap members are owned by this class, so nothing to release.
clsfy_adaboost_trainer::~clsfy_adaboost_trainer()
{
}
00038 
00039 
00040 //: Extracts the j-th element of each vector in data and puts into v
00041 void clsfy_adaboost_trainer::clsfy_get_elements(
00042                               vnl_vector<double>& v,
00043                               mbl_data_wrapper<vnl_vector<double> >& data,
00044                               int j)
00045 {
00046   unsigned long n = data.size();
00047   v.set_size(n);
00048   data.reset();
00049   for (unsigned long i=0;i<n;++i)
00050   {
00051     v[i] = data.current()[j];
00052     data.next();
00053   }
00054 }
00055 
00056 
00057 //: Correctly classified examples have weights scaled by beta
00058 void clsfy_adaboost_trainer::clsfy_update_weights_weak(
00059                               vnl_vector<double> &wts,
00060                               const vnl_vector<double>& data,
00061                               clsfy_classifier_1d& classifier,
00062                               int class_number,
00063                               double beta)
00064 {
00065   assert(class_number >= 0);
00066   unsigned int n = wts.size();
00067   for (unsigned int i=0;i<n;++i)
00068     if (classifier.classify(data[i])==(unsigned)class_number) wts[i]*=beta;
00069 }
00070 
00071 
//: Build classifier composed of 1d classifiers working on individual vector elements
//  Builds an n-component classifier, each component of which is a 1D classifier
//  working on a single element of the input vector.
//  here egs0 are -ve examples
//  and egs1 are +ve examples
//  \param strong_classifier  Output: cleared, then filled with up to
//                            max_n_clfrs weak classifiers plus their alphas.
//  \param max_n_clfrs        Maximum number of boosting rounds (weak classifiers).
//  \param builder            Factory/trainer for the weak 1D classifiers.
//  \param egs0               Negative training examples (class 0).
//  \param egs1               Positive training examples (class 1).
//  NOTE(review): assumes egs0 is non-empty (egs0.current() is read for the
//  dimensionality) and that all examples share that dimensionality — confirm
//  with callers.
void clsfy_adaboost_trainer::build_strong_classifier(
                              clsfy_simple_adaboost& strong_classifier,
                              int max_n_clfrs,
                              clsfy_builder_1d& builder,
                              mbl_data_wrapper<vnl_vector<double> >& egs0,
                              mbl_data_wrapper<vnl_vector<double> >& egs1)
{
  // remove all alphas and classifiers from strong classifier
  strong_classifier.clear();


  // Scratch weak classifier (rebuilt every feature) and the best one found
  // so far this round.  Both are heap objects owned by this function; every
  // exit path below must delete them.
  clsfy_classifier_1d* c1d = builder.new_classifier();
  clsfy_classifier_1d* best_c1d= builder.new_classifier();

  unsigned long n0 = egs0.size();
  unsigned long n1 = egs1.size();
  int n=max_n_clfrs;

  // Dimensionality of data
  unsigned int d = egs0.current().size();
  strong_classifier.set_n_dims(d);

  // Initialise the weights on each sample
  // Each class starts with total weight 0.5, spread uniformly over its examples.
  vnl_vector<double> wts0(n0,0.5/n0);
  vnl_vector<double> wts1(n1,0.5/n1);

  // Reused buffers holding one feature (vector element) of every example.
  vnl_vector<double> egs0_1d, egs1_1d;

  for (int i=0;i<n;++i)
  {
    vcl_cout<<"adaboost training round = "<<i<<'\n';

    //vcl_cout<<"wts0= "<<wts0<<"\nwts1= "<<wts1<<'\n';

    // Search all d features for the weak classifier with lowest weighted error.
    int best_j=-1;
    double min_error= 100000;
    for (unsigned int j=0;j<d;++j)
    {
      //vcl_cout<<"building classifier "<<j<<" of "<<d<<'\n';
      clsfy_get_elements(egs0_1d,egs0,j);
      clsfy_get_elements(egs1_1d,egs1,j);

      double error = builder.build(*c1d,egs0_1d,wts0,egs1_1d,wts1);
      //vcl_cout<<"error= "<<error<<'\n';
      if (j==0 || error<min_error)
      {
        min_error = error;
        // Keep a private clone: c1d is overwritten on the next feature.
        delete best_c1d;
        best_c1d= c1d->clone();
        best_j = j;
      }
    }

    vcl_cout<<"best_j= "<<best_j<<'\n'
            <<"min_error= "<<min_error<<'\n';

    if (min_error<1e-10)  // Hooray!
    {
      // Perfect weak classifier: give it a large finite alpha (the usual
      // -log(beta) would be infinite) and stop training.
      vcl_cout<<"min_error<1e-10 !!!\n";
      double alpha  = vcl_log(2.0*(n0+n1));   //is this appropriate???
      strong_classifier.add_classifier( best_c1d, alpha, best_j);

      // delete classifiers on heap, because clones taken by strong_classifier
      delete c1d;
      delete best_c1d;
      return;
    }


    if (0.5-min_error<1e-10) // Oh dear, no further improvement possible
    {
      // Best weak classifier is no better than chance; adding it would get
      // alpha <= 0, so stop without including it.
      vcl_cout<<"min_error => 0.5 !!!\n";

      // delete classifiers on heap, because clones taken by strong_classifier
      delete c1d;
      delete best_c1d;
      return;
    }

    // Standard AdaBoost update: beta in (0,1), alpha = -log(beta) > 0.
    double beta = min_error/(1.0-min_error);
    double alpha  = -1.0*vcl_log(beta);
    strong_classifier.add_classifier( best_c1d, alpha, best_j);

    if (i<(n-1))  // skip the weight update after the final round
    {
      // apply the best weak classifier
      clsfy_get_elements(egs0_1d,egs0,best_j);
      clsfy_get_elements(egs1_1d,egs1,best_j);

      // Scale down weights of examples the chosen classifier gets right
      // (class 0 for egs0, class 1 for egs1).
      clsfy_update_weights_weak(wts0,egs0_1d,*best_c1d,0,beta);
      clsfy_update_weights_weak(wts1,egs1_1d,*best_c1d,1,beta);

      // normalise the weights
      double w_sum = wts0.mean()*n0 + wts1.mean()*n1;
      wts0/=w_sum;
      wts1/=w_sum;
    }
  }

  delete c1d;
  delete best_c1d;
}
00179 
00180 //=======================================================================
00181 
00182 short clsfy_adaboost_trainer::version_no() const
00183 {
00184   return 1;
00185 }
00186 
00187 //=======================================================================
00188 
00189 vcl_string clsfy_adaboost_trainer::is_a() const
00190 {
00191   return vcl_string("clsfy_adaboost_trainer");
00192 }
00193 
00194 bool clsfy_adaboost_trainer::is_class(vcl_string const& s) const
00195 {
00196   return s == clsfy_adaboost_trainer::is_a();
00197 }
00198 
00199 //=======================================================================
00200 
#if 0
// NOTE(review): disabled template code for the copy constructor and assignment
// operator, kept for when heap-owned members are added.  The members it
// references (data_ptr_, data_) do not exist in the current class.

// required if data stored on the heap is present in this class
clsfy_adaboost_trainer::clsfy_adaboost_trainer(const clsfy_adaboost_trainer& new_b):
  data_ptr_(0)
{
  *this = new_b;
}

//=======================================================================

// required if data stored on the heap is present in this class
clsfy_adaboost_trainer& clsfy_adaboost_trainer::operator=(const clsfy_adaboost_trainer& new_b)
{
  if (&new_b==this) return *this;

  // Copy heap member variables.
  delete data_ptr_; data_ptr_=0;

  if (new_b.data_ptr_)
    data_ptr_ = new_b.data_ptr_->clone();

  // Copy normal member variables
  data_ = new_b.data_;

  return *this;
}

#endif // 0
00230 
00231 //=======================================================================
00232 
00233 // required if data is present in this class
00234 void clsfy_adaboost_trainer::print_summary(vcl_ostream& /*os*/) const
00235 {
00236     // os << data_; // example of data output
00237     vcl_cerr << "clsfy_adaboost_trainer::print_summary() NYI\n";
00238 }
00239 
00240 //=======================================================================
00241 
00242 // required if data is present in this class
00243 void clsfy_adaboost_trainer::b_write(vsl_b_ostream& /*bfs*/) const
00244 {
00245   //vsl_b_write(bfs, version_no());
00246   //vsl_b_write(bfs, data_);
00247   vcl_cerr << "clsfy_adaboost_trainer::b_write() NYI\n";
00248 }
00249 
00250 //=======================================================================
00251 
//: Binary read from stream.
//  Not yet implemented: warns on vcl_cerr and leaves the stream untouched.
//  The #if 0 body below is the standard VXL versioned-read template, ready
//  to enable once the class gains serialisable members.
void clsfy_adaboost_trainer::b_read(vsl_b_istream& /*bfs*/)
{
  vcl_cerr << "clsfy_adaboost_trainer::b_read() NYI\n";
#if 0
  if (!bfs) return;

  short version;
  vsl_b_read(bfs,version);
  switch (version)
  {
  case (1):
    vsl_b_read(bfs,data_);
    break;
  default:
    vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_adaboost_trainer&)\n"
             << "           Unknown version number "<< version << '\n';
    bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
    return;
  }
#endif
}
00274 
00275 //=======================================================================
00276 
00277 void vsl_b_write(vsl_b_ostream& bfs, const clsfy_adaboost_trainer& b)
00278 {
00279   b.b_write(bfs);
00280 }
00281 
00282 //=======================================================================
00283 
00284 void vsl_b_read(vsl_b_istream& bfs, clsfy_adaboost_trainer& b)
00285 {
00286   b.b_read(bfs);
00287 }
00288 
00289 //=======================================================================
00290 
00291 void vsl_print_summary(vcl_ostream& os,const clsfy_adaboost_trainer& b)
00292 {
00293   os << b.is_a() << ": ";
00294   vsl_indent_inc(os);
00295   b.print_summary(os);
00296   vsl_indent_dec(os);
00297 }
00298 
00299 //=======================================================================
00300 
00301 vcl_ostream& operator<<(vcl_ostream& os,const clsfy_adaboost_trainer& b)
00302 {
00303   vsl_print_summary(os,b);
00304   return os;
00305 }