contrib/mul/clsfy/clsfy_direct_boost.cxx
Go to the documentation of this file.
00001 #include "clsfy_direct_boost.h"
00002 //:
00003 // \file
00004 // \brief Classifier using adaboost on combinations of simple 1D classifiers
00005 // \author dac
00006 
00007 //=======================================================================
00008 
00009 #include <vcl_string.h>
00010 #include <vcl_iostream.h>
00011 #include <vcl_vector.h>
00012 #include <vcl_cassert.h>
00013 #include <vcl_cmath.h>
00014 #include <vsl/vsl_binary_io.h>
00015 #include <vsl/vsl_binary_loader.h>
00016 #include <vsl/vsl_vector_io.h>
00017 #include <vnl/io/vnl_io_matrix.h>
00018 #include <vnl/io/vnl_io_vector.h>
00019 
00020 //=======================================================================
00021 //: Default constructor
00022 clsfy_direct_boost::clsfy_direct_boost()
00023 : n_clfrs_used_(-1) , n_dims_(-1)
00024 {
00025 }
00026 
00027 clsfy_direct_boost::clsfy_direct_boost(const clsfy_direct_boost& c)
00028   : clsfy_classifier_base()
00029 {
00030   *this = c;
00031 }
00032 
00033 //: Copy operator
00034 clsfy_direct_boost& clsfy_direct_boost::operator=(const clsfy_direct_boost& c)
00035 {
00036   delete_stuff();
00037 
00038   int n = c.classifier_1d_.size();
00039   classifier_1d_.resize(n);
00040   for (int i=0;i<n;++i)
00041     classifier_1d_[i] = c.classifier_1d_[i]->clone();
00042 
00043   threshes_ = c.threshes_;
00044   wts_ = c.wts_;
00045   index_ = c.index_;
00046   return *this;
00047 }
00048 
00049 //: Delete objects on heap
00050 void clsfy_direct_boost::delete_stuff()
00051 {
00052   for (unsigned int i=0;i<classifier_1d_.size();++i)
00053     delete classifier_1d_[i];
00054 
00055   classifier_1d_.resize(0);
00056 
00057   threshes_.resize(0);
00058   wts_.resize(0);
00059   index_.resize(0);
00060   n_clfrs_used_= -1;
00061 }
00062 
00063 //: Destructor
00064 clsfy_direct_boost::~clsfy_direct_boost()
00065 {
00066   delete_stuff();
00067 }
00068 
00069 
00070 //: Comparison
00071 bool clsfy_direct_boost::operator==(const clsfy_direct_boost& x) const
00072 {
00073   if (x.classifier_1d_.size() != classifier_1d_.size() ) return false;
00074   int n= x.classifier_1d_.size();
00075   for (int i=0; i<n; ++i)
00076     if (!(*(x.classifier_1d_[i]) == *(classifier_1d_[i]) )) return false;
00077 
00078   return  x.threshes_ == threshes_ &&
00079           x.wts_ == wts_ &&
00080           x.index_ == index_;
00081 }
00082 
00083 
00084 //: Set parameters.  Clones taken of *classifier[i]
00085 void clsfy_direct_boost::set_parameters(
00086                       const vcl_vector<clsfy_classifier_1d*>& classifier,
00087                       const vcl_vector<double>& threshes,
00088                       const vcl_vector<double>& wts,
00089                       const vcl_vector<int>& index)
00090 {
00091   delete_stuff();
00092 
00093   int n = classifier.size();
00094   classifier_1d_.resize(n);
00095   for (int i=0;i<n;++i)
00096     classifier_1d_[i] = classifier[i]->clone();
00097 
00098   threshes_ = threshes;
00099   wts_ = wts;
00100   index_= index;
00101 }
00102 
00103 
00104 //: Clear all wts and classifiers
00105 void clsfy_direct_boost::clear()
00106 {
00107   delete_stuff();
00108 }
00109 
00110 
00111 //: Add weak classifier and alpha value
00112 // nb also changes n_clfrs_used to use all current weak classifiers
00113 // nb calc total threshold (ie threshes_ separately, see below)
00114 void clsfy_direct_boost::add_one_classifier(clsfy_classifier_1d* c1d,
00115                                             double wt,
00116                                             int index)
00117 {
00118   classifier_1d_.push_back(c1d->clone());
00119   wts_.push_back(wt);
00120   index_.push_back(index);
00121   n_clfrs_used_=wts_.size();
00122 }
00123 
00124 
00125 //: Add one threshold
00126 void clsfy_direct_boost::add_one_threshold(double thresh)
00127 {
00128   threshes_.push_back(thresh);
00129 }
00130 
00131 
00132 //: Add final threshold
00133 void clsfy_direct_boost::add_final_threshold(double thresh)
00134 {
00135   int n= threshes_.size();
00136   threshes_[n-1]= thresh;
00137 }
00138 
00139 
00140 //: Classify the input vector.
00141 // Returns either 1 (for positive class) or 0 (for negative class)
00142 unsigned clsfy_direct_boost::classify(const vnl_vector<double> &v) const
00143 {
00144   //vcl_cout<<"wts_.size()= "<<wts_.size()<<vcl_endl
00145   //        <<"n_clfrs_used_= "<<n_clfrs_used_<<vcl_endl;
00146   assert ( n_clfrs_used_ >= 0);
00147   assert ( (unsigned)n_clfrs_used_ <= wts_.size() );
00148   assert ( n_dims_ >= 0);
00149   assert ( v.size() == (unsigned)n_dims_ );
00150 
00151 
00152   double sum = 0.0;
00153   for (int i=0;i<n_clfrs_used_;++i)
00154     sum+= wts_[i]* classifier_1d_[i]->log_l( v[ index_[i] ] );
00155     //sum += wts_[i]*classifier_1d_[i]->classify(v[ index_[i] ]);
00156 
00157   if (sum < threshes_[n_clfrs_used_-1] ) return 1;
00158   return 0;
00159 }
00160 
00161 //=======================================================================
00162 
00163 //: Find the posterior probability of the input being in the positive class.
00164 // The result is outputs(0)
00165 void clsfy_direct_boost::class_probabilities(vcl_vector<double> &outputs,
00166                                              const vnl_vector<double> &input) const
00167 {
00168   outputs.resize(1);
00169   outputs[0] = 1.0 / (1.0 + vcl_exp(-log_l(input)));
00170 }
00171 
00172 //=======================================================================
00173 
00174 //: Log likelihood of being in the positive class.
00175 // Class probability = 1 / (1+exp(-log_l))
00176 double clsfy_direct_boost::log_l(const vnl_vector<double> &v) const
00177 {
00178   assert ( n_clfrs_used_ >= 0);
00179   assert ( (unsigned)n_clfrs_used_ <= wts_.size() );
00180   //assert ( n_dims_ != -1);
00181   //assert ( v.size() == n_dims_ );
00182   double sum = 0.0;
00183   for (int i=0;i<n_clfrs_used_;++i)
00184     sum+= wts_[i]* classifier_1d_[i]->log_l( v[ index_[i] ] );
00185     //sum += wts_[i]*classifier_1d_[i]->classify(v[ index_[i] ]);
00186 
00187   return sum;  // this isn't really a log likelihood, because the lower this
00188                // value the more likely a vector is to be a pos example
00189 }
00190 
00191 //=======================================================================
00192 
00193 vcl_string clsfy_direct_boost::is_a() const
00194 {
00195   return vcl_string("clsfy_direct_boost");
00196 }
00197 
00198 //=======================================================================
00199 
00200 bool clsfy_direct_boost::is_class(vcl_string const& s) const
00201 {
00202   return s == clsfy_direct_boost::is_a() || clsfy_classifier_base::is_class(s);
00203 }
00204 
00205 //=======================================================================
00206 
00207 // required if data is present in this class
00208 void clsfy_direct_boost::print_summary(vcl_ostream& os) const
00209 {
00210   int n = wts_.size();
00211   assert( wts_.size() == index_.size() );
00212   os<<'\n';
00213   for (int i=0;i<n;++i)
00214   {
00215     os<<" Weights: "<<wts_[i]
00216       <<" Index: "<<index_[i]
00217       <<" Total Threshold: "<<threshes_[i]
00218       <<" Classifier: "<<classifier_1d_[i]<<'\n';
00219   }
00220 }
00221 
00222 //=======================================================================
00223 
00224 short clsfy_direct_boost::version_no() const
00225 {
00226   return 1;
00227 }
00228 
00229 //=======================================================================
00230 
00231 void clsfy_direct_boost::b_write(vsl_b_ostream& bfs) const
00232 {
00233   vsl_b_write(bfs,version_no());
00234   vsl_b_write(bfs,classifier_1d_);
00235   vsl_b_write(bfs,threshes_);
00236   vsl_b_write(bfs,wts_);
00237   vsl_b_write(bfs,index_);
00238 }
00239 
00240 //=======================================================================
00241 
00242 void clsfy_direct_boost::b_read(vsl_b_istream& bfs)
00243 {
00244   if (!bfs) return;
00245 
00246   delete_stuff();
00247 
00248   short version;
00249   vsl_b_read(bfs,version);
00250   switch (version)
00251   {
00252     case (1):
00253       vsl_b_read(bfs,classifier_1d_);
00254       vsl_b_read(bfs,threshes_);
00255       vsl_b_read(bfs,wts_);
00256       vsl_b_read(bfs,index_);
00257 
00258       // set default number of classifiers used to be the maximum number
00259       n_clfrs_used_= index_.size();
00260 
00261       break;
00262     default:
00263       vcl_cerr << "I/O ERROR: clsfy_direct_boost::b_read(vsl_b_istream&)\n"
00264                << "           Unknown version number "<< version << '\n';
00265       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00266   }
00267 }