contrib/mul/clsfy/clsfy_binary_threshold_1d_builder.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_threshold_1d_builder.cxx
00002 #include "clsfy_binary_threshold_1d_builder.h"
00003 //:
00004 // \file
00005 // \author dac
00006 // \date   Tue Mar  5 01:11:31 2002
00007 
00008 #include <vcl_iostream.h>
00009 #include <vcl_string.h>
00010 #include <vcl_cassert.h>
00011 #include <vsl/vsl_binary_loader.h>
00012 #include <vnl/vnl_double_2.h>
00013 #include <clsfy/clsfy_builder_1d.h>
00014 #include <clsfy/clsfy_binary_threshold_1d.h>
00015 #include <vcl_algorithm.h>
00016 
00017 //=======================================================================
00018 
00019 clsfy_binary_threshold_1d_builder::clsfy_binary_threshold_1d_builder()
00020 {
00021 }
00022 
00023 //=======================================================================
00024 
00025 clsfy_binary_threshold_1d_builder::~clsfy_binary_threshold_1d_builder()
00026 {
00027 }
00028 
00029 //=======================================================================
00030 
00031 short clsfy_binary_threshold_1d_builder::version_no() const
00032 {
00033   return 1;
00034 }
00035 
00036 
00037 //: Create empty classifier
00038 // Caller is responsible for deletion
00039 clsfy_classifier_1d* clsfy_binary_threshold_1d_builder::new_classifier() const
00040 {
00041   return new clsfy_binary_threshold_1d();
00042 }
00043 
00044 
00045 //: Build a binary_threshold classifier
00046 //  Train classifier, returning weighted error
00047 //  Selects parameters of classifier which best separate examples from two classes,
00048 //  weighting examples appropriately when estimating the misclassification rate.
00049 //  Returns weighted sum of error, e.wts, where e_i =0 for correct classifications,
00050 //  e_i=1 for incorrect.
00051 double clsfy_binary_threshold_1d_builder::build(clsfy_classifier_1d& classifier,
00052                                                 const vnl_vector<double>& egs,
00053                                                 const vnl_vector<double>& wts,
00054                                                 const vcl_vector<unsigned> &outputs) const
00055 {
00056   // this method sorts the data and passes it to the method below
00057   assert(classifier.is_class("clsfy_binary_threshold_1d"));
00058 
00059   unsigned int n= egs.size();
00060   assert ( wts.size() == n );
00061   assert ( outputs.size() == n );
00062 
00063   // create triples data, so can sort
00064   vcl_vector<vbl_triple<double,int,int> > data;
00065 
00066   vbl_triple<double,int,int> t;
00067   // add data to triples
00068   for (unsigned int i=0;i<n;++i)
00069   {
00070     t.first=egs(i);
00071     t.second= outputs[i];
00072     t.third = i;
00073     data.push_back(t);
00074   }
00075 
00076   vbl_triple<double,int,int> *data_ptr=&data[0];
00077   vcl_sort(data_ptr,data_ptr+n);
00078   return build_from_sorted_data(classifier, &data[0], wts);
00079 }
00080 
00081 
00082 //: Build a binary_threshold classifier
00083 // nb here egs0 are -ve examples
00084 // and egs1 are +ve examples
00085 double clsfy_binary_threshold_1d_builder::build(clsfy_classifier_1d& classifier,
00086                                                 vnl_vector<double>& egs0,
00087                                                 vnl_vector<double>& wts0,
00088                                                 vnl_vector<double>& egs1,
00089                                                 vnl_vector<double>& wts1)  const
00090 {
00091   // this method sorts the data and passes it to the method below
00092   assert(classifier.is_class("clsfy_binary_threshold_1d"));
00093 
00094   vcl_vector<vbl_triple<double,int,int> > data;
00095   unsigned int n0 = egs0.size();
00096   unsigned int n1 = egs1.size();
00097   vnl_vector<double> wts(n0+n1);
00098   vbl_triple<double,int,int> t;
00099   // add data for class 0
00100   for (unsigned int i=0;i<n0;++i)
00101   {
00102     t.first=egs0[i];
00103     t.second=0;
00104     t.third = i;
00105     wts(i)= wts0[i];
00106     data.push_back(t);
00107   }
00108 
00109   // add data for class 1
00110   for (unsigned int i=0;i<n1;++i)
00111   {
00112     t.first=egs1[i];
00113     t.second=1;
00114     t.third = i+n0;
00115     wts(i+n0)= wts1[i];
00116     data.push_back(t);
00117   }
00118 
00119   unsigned int n=n0+n1;
00120 
00121   vbl_triple<double,int,int> *data_ptr=&data[0];
00122   vcl_sort(data_ptr,data_ptr+n);
00123 
00124   return build_from_sorted_data(classifier,&data[0], wts);
00125 }
00126 
00127 
00128 //: Train classifier, returning weighted error
00129 //   Assumes two classes
00130 double clsfy_binary_threshold_1d_builder::build_from_sorted_data(
00131                                   clsfy_classifier_1d& classifier,
00132                                   const vbl_triple<double,int,int> *data,
00133                                   const vnl_vector<double>& wts
00134                                   ) const
00135 {
00136   // here the triple consists of (value, class number, example index)
00137   // the example index specifies the weight of each example
00138   //
00139   // NB DATA must be sorted for this to work!!!!
00140 
00141 
00142   // calc total weights for class0 and class1 separately
00143   unsigned int n=wts.size();
00144   double tot_wts0=0.0, tot_wts1=0.0;
00145   for (unsigned int i=0;i<n;++i)
00146     if (data[i].second==0)
00147       tot_wts0+=wts(data[i].third);
00148     else
00149       tot_wts1+=wts(data[i].third);
00150 
00151   double e0=0.0, e1=0.0, min_err=2.0;
00152   double etot0,etot1;
00153   unsigned int index=n; int polarity=0;
00154   for (unsigned int i=0;i<n;++i)
00155   {
00156     if (data[i].second==0)
00157       e0+=wts(data[i].third);
00158     else
00159       e1+=wts(data[i].third);
00160 
00161     etot0=(tot_wts0-e0) +e1;
00162     etot1=(tot_wts1-e1) +e0;
00163 
00164     if ( etot0< min_err)
00165     {
00166       // i.e. class1 is maximally separated from class0 at this point
00167       // also members of class1 are generally greater than members of class0
00168       polarity=+1;        //indicates direction of > sign
00169       index=i;            //the threshold
00170 
00171       min_err= etot0;
00172     }
00173 
00174     if ( etot1< min_err)
00175     {
00176       // i.e. class1 is maximally separated from class0 at this point
00177       // also members of class1 are generally less than members of class0
00178       polarity=-1;        //indicates direction of > sign
00179       index=i;            //the threshold
00180 
00181       min_err= etot1;
00182     }
00183   }
00184 
00185   assert ( index!=n );
00186 
00187   // determine threshold from data index
00188   double threshold;
00189   if ( index+1==n )
00190     threshold=data[index].first+0.01;
00191   else
00192     threshold=(data[index].first+data[index+1].first)/2;
00193 
00194   // pass parameters to classifier
00195   vnl_double_2 params(polarity, threshold*polarity);
00196   classifier.set_params(params.as_vector());
00197   return min_err;
00198 }
00199 
00200 //=======================================================================
00201 
00202 vcl_string clsfy_binary_threshold_1d_builder::is_a() const
00203 {
00204   return vcl_string("clsfy_binary_threshold_1d_builder");
00205 }
00206 
00207 bool clsfy_binary_threshold_1d_builder::is_class(vcl_string const& s) const
00208 {
00209   return s == clsfy_binary_threshold_1d_builder::is_a() || clsfy_builder_1d::is_class(s);
00210 }
00211 
00212 //=======================================================================
00213 
00214 clsfy_builder_1d* clsfy_binary_threshold_1d_builder::clone() const
00215 {
00216   return new clsfy_binary_threshold_1d_builder(*this);
00217 }
00218 
00219 //=======================================================================
00220 
00221 // required if data is present in this base class
00222 void clsfy_binary_threshold_1d_builder::print_summary(vcl_ostream& /*os*/) const
00223 {
00224 }
00225 
00226 //=======================================================================
00227 
00228 // required if data is present in this base class
00229 void clsfy_binary_threshold_1d_builder::b_write(vsl_b_ostream& bfs) const
00230 {
00231   short version_no=1;
00232   vsl_b_write(bfs, version_no);
00233 }
00234 
00235 //=======================================================================
00236 
00237   // required if data is present in this base class
00238 void clsfy_binary_threshold_1d_builder::b_read(vsl_b_istream& bfs)
00239 {
00240   if (!bfs) return;
00241 
00242   short version;
00243   vsl_b_read(bfs,version);
00244   switch (version)
00245   {
00246   case 1:
00247     break;
00248   default:
00249     vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_binary_threshold_1d_builder&)\n"
00250              << "           Unknown version number "<< version << '\n';
00251     bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00252     return;
00253   }
00254 }