contrib/mul/clsfy/clsfy_binary_threshold_1d_gini_builder.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_threshold_1d_gini_builder.cxx
00002 #include "clsfy_binary_threshold_1d_gini_builder.h"
00003 //:
00004 // \file
00005 // \author Martin Roberts
00006 
00007 #include <vcl_iostream.h>
00008 #include <vcl_string.h>
00009 #include <vcl_cassert.h>
00010 #include <vsl/vsl_binary_loader.h>
00011 #include <vnl/vnl_double_2.h>
00012 #include <clsfy/clsfy_builder_1d.h>
00013 #include <clsfy/clsfy_binary_threshold_1d.h>
00014 #include <vcl_algorithm.h>
00015 
// Note this is used by clsfy_binary_tree_builder.
// Derived from clsfy_binary_threshold_1d_builder but uses a slightly different
// interface to do the Gini index optimisation, as this returns the (negated)
// reduction in the Gini impurity (not the classification error).
00020 
00021 //=======================================================================
00022 
00023 clsfy_binary_threshold_1d_gini_builder::clsfy_binary_threshold_1d_gini_builder()
00024 {
00025 }
00026 
00027 //=======================================================================
00028 
00029 clsfy_binary_threshold_1d_gini_builder::~clsfy_binary_threshold_1d_gini_builder()
00030 {
00031 }
00032 
00033 //=======================================================================
00034 
00035 short clsfy_binary_threshold_1d_gini_builder::version_no() const
00036 {
00037     return 1;
00038 }
00039 
00040 
00041 //: Create empty classifier
00042 // Caller is responsible for deletion
00043 clsfy_classifier_1d* clsfy_binary_threshold_1d_gini_builder::new_classifier() const
00044 {
00045     return new clsfy_binary_threshold_1d();
00046 }
00047 
00048 //: Build a binary_threshold classifier
00049 //  Train classifier
00050 //  Selects parameters of classifier which best separate examples from two classes,
00051 // Uses the gini impurity index
00052 // Note it returns the -reduction in Gini impurity produced by the split
00053 // Not the misclassification rate
00054 // (i.e. but minimise as per error rate)
00055 double clsfy_binary_threshold_1d_gini_builder::build_gini(clsfy_classifier_1d& classifier,
00056                                                           const vnl_vector<double>&  inputs,
00057                                                           const vcl_vector<unsigned> &outputs) const
00058 {
00059     assert(classifier.is_class("clsfy_binary_threshold_1d"));
00060 
00061     unsigned n = inputs.size();
00062     assert ( outputs.size() == n );
00063 
00064     // create triples data, so can sort
00065     vcl_vector<vbl_triple<double,int,int> > data;
00066     data.reserve(n);
00067 
00068     //First just create sorted data
00069     vcl_vector<unsigned >::const_iterator classIter=outputs.begin();
00070     vnl_vector<double  >::const_iterator inputIter=inputs.begin();
00071     vnl_vector<double  >::const_iterator inputIterEnd=inputs.end();
00072     vbl_triple<double,int,int> t;
00073     unsigned i=0;
00074     while (inputIter != inputIterEnd)
00075     {
00076         t.first = *inputIter++;
00077         t.second=*classIter++;
00078         t.third = i++;
00079         data.push_back(t);
00080     }
00081 
00082     assert(i==inputs.size());
00083 
00084     vcl_sort(data.begin(),data.end());
00085     return build_gini_from_sorted_data(static_cast<clsfy_classifier_1d&>(classifier), data);
00086 }
00087 
00088 
00089 //: Train classifier, returning weighted error
00090 //   Assumes two classes
00091 //  Note that input "data" must be sorted to use this routine
00092 //Return -improvement in impurity (as normally these builders minimise)
00093 double clsfy_binary_threshold_1d_gini_builder::build_gini_from_sorted_data(
00094     clsfy_classifier_1d& classifier,
00095     const vcl_vector<vbl_triple<double,int,int> >& data) const
00096 {
00097     // here the triple consists of (value, class number, example index)
00098     // the example index specifies the weight of each example
00099     //
00100     // NB DATA must be sorted for this to work!!!!
00101 
00102     //Validate that the data is not homogeneous
00103     const double epsilon=1.0E-20;
00104     if (vcl_fabs(data.front().first-data.back().first)<epsilon)
00105     {
00106         vcl_cerr<<"WARNING - clsfy_binary_threshold_1d_gini_builder::build_from_sorted_data - homogeneous data - cannot split\n";
00107         int polarity=1;
00108         double threshold=data[0].first;
00109         vnl_double_2 params(polarity, threshold*polarity);
00110         classifier.set_params(params.as_vector());
00111         return 1.0;
00112     }
00113 
00114     unsigned int ntot=data.size();
00115     double dntot=double (ntot);
00116     vcl_vector<vbl_triple<double,int,int> >::const_iterator dataIter=data.begin();
00117     vcl_vector<vbl_triple<double,int,int> >::const_iterator dataIterEnd=data.end();
00118     unsigned n0Tot=0;
00119     unsigned n1Tot=0;
00120     while (dataIter != dataIterEnd)
00121     {
00122         if (dataIter->second==0)
00123             ++n0Tot;
00124         else
00125             ++n1Tot;
00126         ++dataIter;
00127     }
00128 
00129     double parentImp=0.0;
00130     //Parent level impurity to start with
00131 
00132     double p=double (n0Tot)/dntot;
00133     parentImp=2.0*p*(1-p);
00134 
00135     dataIter=data.begin();
00136     double s=dataIter->first-epsilon;
00137     double deltaImpBest= -1.0; //initialise to split makes it worse
00138     double  sbest=s;
00139     //Put none into left bin, all else go right
00140     unsigned nL0=0;
00141     unsigned nL1=0;
00142     unsigned nR0=n0Tot;
00143     unsigned nR1=n1Tot;
00144     double parity=1.0;
00145     while (dataIter != dataIterEnd)
00146     {
00147         s=dataIter->first;
00148         vcl_vector<vbl_triple<double,int,int> >::const_iterator dataIterNext=dataIter;
00149 
00150         //Increment till threshold increases (may have some same data values)
00151         while (dataIterNext != dataIterEnd && (dataIterNext->first-s)<epsilon)
00152         {
00153             if (dataIterNext->second==0)
00154             {
00155                 ++nL0;
00156                 --nR0;
00157             }
00158             else
00159             {
00160                 ++nL1;
00161                 --nR1;
00162             }
00163             ++dataIterNext;
00164         }
00165 
00166         unsigned nLTot=nL0+nL1;
00167         unsigned nRTot=nR0+nR1;
00168         double probL=double(nL0)/double(nLTot);
00169         double probR=double(nR1)/double(nRTot);
00170         //Two-class Gini index for left and right
00171         double impL=2.0*probL*(1-probL);
00172         double impR=2.0*probR*(1-probR);
00173 
00174         //Proportional weights
00175         double pL=double (nLTot)/dntot;
00176         double pR=1.0-pL;
00177 
00178         double deltaImp=parentImp-(pL*impL + pR*impR);
00179         if (deltaImp>deltaImpBest)
00180         {
00181             deltaImpBest=deltaImp;
00182             sbest=s;
00183             if (nR1>=nL1) //More class 1 are going above thresh
00184                 parity=1;
00185             else
00186                 parity=-1; //Reverse sign as more class one are going below thresh
00187         }
00188 
00189         dataIter=dataIterNext;
00190     }
00191 
00192     double threshold=sbest;
00193 
00194     // pass parameters to classifier
00195     vnl_double_2 params(parity, threshold*parity);
00196     classifier.set_params(params.as_vector());
00197     return -deltaImpBest;
00198 }
00199 
00200 //=======================================================================
00201 
00202 vcl_string clsfy_binary_threshold_1d_gini_builder::is_a() const
00203 {
00204   return vcl_string("clsfy_binary_threshold_1d_gini_builder");
00205 }
00206 
00207 bool clsfy_binary_threshold_1d_gini_builder::is_class(vcl_string const& s) const
00208 {
00209   return s == clsfy_binary_threshold_1d_gini_builder::is_a() || clsfy_builder_1d::is_class(s);
00210 }
00211 
#if 0
//=======================================================================
// This region is compiled out by the preprocessor guard.

//: Copy constructor
// required if data stored on the heap is present in this derived class
clsfy_binary_threshold_1d_gini_builder::clsfy_binary_threshold_1d_gini_builder(
                             const clsfy_binary_threshold_1d_gini_builder& new_b) :
  data_ptr_(0)
{
  this->operator=(new_b);
}

//=======================================================================

//: Assignment - copies the clsfy_binary_threshold_1d_builder part of new_b
// required if data stored on the heap is present in this derived class
clsfy_binary_threshold_1d_gini_builder&
clsfy_binary_threshold_1d_gini_builder::operator=(const clsfy_binary_threshold_1d_gini_builder& new_b)
{
    if (this != &new_b)
    {
        clsfy_binary_threshold_1d_builder& self_base = *this;
        const clsfy_binary_threshold_1d_builder& other_base = new_b;
        self_base = other_base;
    }

    return *this;
}
#endif
00238 //=======================================================================
00239 
00240 // required if data is present in this base class
00241 void clsfy_binary_threshold_1d_gini_builder::print_summary(vcl_ostream& os) const
00242 {
00243     os<<"clsfy_binary_threshold_1d_gini_builder"<<vcl_endl;
00244 }
00245 
00246 //=======================================================================
00247 
00248 
00249 clsfy_builder_1d* clsfy_binary_threshold_1d_gini_builder::clone() const
00250 {
00251     return new clsfy_binary_threshold_1d_gini_builder(*this);
00252 }
00253 //=======================================================================
00254 
00255 // required if data is present in this base class
00256 void clsfy_binary_threshold_1d_gini_builder::b_write(vsl_b_ostream& bfs) const
00257 {
00258     vsl_b_write(bfs, version_no());
00259     clsfy_binary_threshold_1d_builder::b_write(bfs);
00260 }
00261 
00262 //=======================================================================
00263 
00264   // required if data is present in this base class
00265 void clsfy_binary_threshold_1d_gini_builder::b_read(vsl_b_istream& bfs)
00266 {
00267     if (!bfs) return;
00268 
00269     short version;
00270     vsl_b_read(bfs,version);
00271     switch (version)
00272     {
00273         case (1):
00274             clsfy_binary_threshold_1d_builder::b_read(bfs);
00275             break;
00276         default:
00277             vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_binary_threshold_1d_gini_builder&)\n"
00278                      << "           Unknown version number "<< version << '\n';
00279             bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00280             return;
00281     }
00282 }