contrib/mul/clsfy/clsfy_random_forest.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_random_forest.cxx
00002 #include "clsfy_random_forest.h"
00003 //:
00004 // \file
00005 // \brief Random forest classifier
00006 // \author Martin Roberts
00007 
00008 #include <vcl_string.h>
00009 #include <vcl_deque.h>
00010 #include <vcl_algorithm.h>
00011 #include <vcl_iterator.h>
00012 #include <vcl_cmath.h>
00013 #include <vcl_cassert.h>
00014 #include <vsl/vsl_binary_io.h>
00015 #include <vsl/vsl_vector_io.h>
00016 #include <vnl/io/vnl_io_vector.h>
00017 #include <mbl/mbl_cloneable_ptr.h>
00018 
00019 
00020 clsfy_random_forest::clsfy_random_forest()
00021 {
00022 }
00023 
00024 //=======================================================================
00025 //: Return the classification of the given probe vector.
00026 unsigned clsfy_random_forest::classify(const vnl_vector<double> &input) const
00027 {
00028 #if 1 //Accumulate probabilities (impure final nodes may not return 0 or 1)
00029     vcl_vector<double > classProbs(1,0.0);
00030     class_probabilities(classProbs,input);
00031     return (classProbs[0]>=0.5) ? 1 : 0;
00032 #else // just accumulate number in each class rather than probs
00033 
00034     vcl_vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIter=trees_.begin();
00035     vcl_vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIterEnd=trees_.end();
00036 
00037     vcl_vector<unsigned > classCount(2,0);
00038 
00039     unsigned i=0;
00040     while (treeIter != treeIterEnd)
00041     {
00042         mbl_cloneable_ptr<clsfy_classifier_base> pTree=*treeIter++;
00043         unsigned treeClass= pTree->classify(input);
00044 
00045         ++classCount[treeClass];
00046     }
00047     if (classCount[0] >= classCount[1])
00048         return 0;
00049     else
00050         return 1;
00051 #endif // 1
00052 }
00053 
00054 //=======================================================================
00055 //: Return a probability like value that the input being in each class.
00056 // output(i) i<<nClasses, contains the probability that the input is in class i
00057 void clsfy_random_forest::class_probabilities(vcl_vector<double>& outputs,
00058                                               vnl_vector<double>const& input) const
00059 {
00060     outputs.resize(1);
00061 
00062     vcl_vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIter=trees_.begin();
00063     vcl_vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIterEnd=trees_.end();
00064 
00065     vcl_vector<double > classProbs(1,0.0);
00066     vcl_vector<double > meanProbs(1,0.0);
00067 
00068     while (treeIter != treeIterEnd)
00069     {
00070         const clsfy_classifier_base* pTree=(*treeIter).ptr();
00071         pTree->class_probabilities(classProbs, input);
00072         meanProbs[0]+=classProbs[0];
00073         ++treeIter;
00074     }
00075     outputs[0]=meanProbs[0]/double (trees_.size());
00076 }
00077 
00078 
00079 //=======================================================================
00080 //: This value has properties of a Log likelihood of being in class (binary classifiers only)
00081 // class probability = exp(logL) / (1+exp(logL))
00082 double clsfy_random_forest::log_l(const vnl_vector<double> &input) const
00083 {
00084     //Retain logistic function relation to prob
00085     //i.e. invert the above relation
00086     double epsilon=1.0E-8;
00087     vcl_vector<double > probs(1,0.5);
00088     class_probabilities(probs,input);
00089     double p=probs[0];
00090     double d=(1.0/p)-1.0;
00091     double x=1.0;
00092     if (d>epsilon)
00093         x=-vcl_log(d);
00094     else
00095         x=-vcl_log(epsilon);
00096 
00097     return x;
00098 }
00099 
00100 //======================= Out of Bag add-ons ==============================
00101 void clsfy_random_forest::class_probabilities_oob(vcl_vector<double> &outputs,
00102                                                   const vnl_vector<double> &input,
00103                                                   const vcl_vector<vcl_vector<unsigned > >& oobIndices,
00104                                                   unsigned this_index) const
00105 {
00106     outputs.resize(1);
00107 
00108     vcl_vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIter=trees_.begin();
00109     vcl_vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIterEnd=trees_.end();
00110 
00111     vcl_vector<double > classProbs(1,0.0);
00112     vcl_vector<double > meanProbs(1,0.0);
00113     vcl_vector<vcl_vector<unsigned > >::const_iterator oobIndexIter=oobIndices.begin() ;
00114     unsigned noob=0;
00115     while (treeIter != treeIterEnd)
00116     {
00117         if (vcl_find(oobIndexIter->begin(),oobIndexIter->end(),this_index)==oobIndexIter->end())
00118         {
00119             //Not found this_index, so Out of Bag - accumulate this tree's vote
00120             const clsfy_classifier_base* pTree=(*treeIter).ptr();
00121 
00122             pTree->class_probabilities(classProbs, input);
00123             meanProbs[0]+=classProbs[0];
00124             ++noob;
00125         }
00126         ++treeIter;
00127         ++oobIndexIter;
00128     }
00129     outputs[0]=meanProbs[0]/double (noob);
00130 }
00131 
00132 
00133 //: Return the classification of the given probe vector using out of bag trees only.
00134 // See also class_probabilities_oob
00135 unsigned clsfy_random_forest::classify_oob(const vnl_vector<double> &input,
00136                                            const vcl_vector<vcl_vector<unsigned > >& oobIndices,
00137                                            unsigned this_index) const
00138 {
00139     vcl_vector<double > classProbs(1,0.0);
00140     class_probabilities_oob(classProbs,input,oobIndices,this_index);
00141     return (classProbs[0]>=0.5) ? 1 : 0;
00142 }
00143 
00144 
00145 //=======================================================================
00146 
00147 vcl_string clsfy_random_forest::is_a() const
00148 {
00149     return vcl_string("clsfy_random_forest");
00150 }
00151 
00152 //=======================================================================
00153 
00154 bool clsfy_random_forest::is_class(vcl_string const& s) const
00155 {
00156     return s == clsfy_random_forest::is_a() || clsfy_classifier_base::is_class(s);
00157 }
00158 
00159 //=======================================================================
00160 
00161 short clsfy_random_forest::version_no() const
00162 {
00163     return 1;
00164 }
00165 
00166 //=======================================================================
00167 
00168 clsfy_classifier_base* clsfy_random_forest::clone() const
00169 {
00170     return new clsfy_random_forest(*this);
00171 }
00172 
00173 //=======================================================================
00174 
00175 void clsfy_random_forest::print_summary(vcl_ostream& os) const
00176 {
00177     os<<"clsfy_random_forest\t has "<<trees_.size()<<" trees"<<vcl_endl;
00178 }
00179 
00180 //=======================================================================
00181 
00182 void clsfy_random_forest::b_write(vsl_b_ostream& bfs) const
00183 {
00184     vcl_cout<<"clsfy_random_forest::b_write"<<vcl_endl;
00185     vsl_b_write(bfs,version_no());
00186     unsigned n=trees_.size();
00187     vsl_b_write(bfs,n);
00188     for (unsigned i=0; i<n;++i)
00189     {
00190         trees_[i]->b_write(bfs);
00191     }
00192 }
00193 
00194 //=======================================================================
00195 
00196 void clsfy_random_forest::b_read(vsl_b_istream& bfs)
00197 {
00198     if (!bfs) return;
00199 
00200     prune();
00201     short version;
00202     vsl_b_read(bfs,version);
00203     switch (version)
00204     {
00205         case 1:
00206         {
00207             unsigned n;
00208             vsl_b_read(bfs,n);
00209             vcl_cout<<"Am attemptig to read in "<<n<<"\t trees"<<vcl_endl;
00210             trees_.reserve(n);
00211             for (unsigned i=0; i<n;++i)
00212             {
00213 //                vcl_cout<<"reading tree "<<i<<vcl_endl;
00214                 mbl_cloneable_ptr< clsfy_classifier_base> tree(new clsfy_binary_tree);
00215                 trees_.push_back(tree);
00216                 trees_.back()->b_read(bfs);
00217             }
00218             break;
00219         }
00220 
00221         default:
00222             vcl_cerr << "I/O ERROR: clsfy_random_forest::b_read(vsl_b_istream&)\n"
00223                      << "           Unknown version number "<< version << '\n';
00224             bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00225     }
00226 }
00227 
00228 clsfy_random_forest::~clsfy_random_forest()
00229 {
00230     prune();
00231 }
00232 
00233 void clsfy_random_forest::prune()
00234 {
00235     trees_.clear(); //note mbl wrapper destructor deletes the tree pointer!
00236 }
00237 
00238 //=======================================================================
00239 //: The dimensionality of input vectors.
00240 unsigned clsfy_random_forest::n_dims() const
00241 {
00242     if (trees_.empty())
00243         return 0;
00244     else
00245         return trees_.front()->n_dims();
00246 }
00247 
00248 clsfy_random_forest& clsfy_random_forest::operator+=(const clsfy_random_forest& forest2)
00249 {
00250     this->trees_.reserve(this->trees_.size()+forest2.trees_.size());
00251     this->trees_.insert(this->trees_.end(),
00252                         forest2.trees_.begin(),forest2.trees_.end());
00253     return *this;
00254 }
00255 
00256 
00257 //============ Friend functions for merging stuff ====================
00258 
00259 //: Merge the sub-forests in the input filenames into a single larger one
00260 void merge_sub_forests(const vcl_vector<vcl_string>& filenames,
00261                        clsfy_random_forest& large_forest)
00262 {
00263     vcl_vector<vcl_string>::const_iterator fileIter=filenames.begin();
00264     vcl_vector<vcl_string>::const_iterator fileIterEnd=filenames.end();
00265     while (fileIter != fileIterEnd)
00266     {
00267         vsl_b_ifstream bfs_in(*fileIter);
00268         clsfy_random_forest subForest;
00269         vsl_b_read(bfs_in, subForest);
00270         bfs_in.close();
00271         large_forest.trees_.reserve(large_forest.trees_.size()+subForest.trees_.size());
00272         large_forest.trees_.insert(large_forest.trees_.end(),
00273                                    subForest.trees_.begin(),subForest.trees_.end());
00274         ++fileIter;
00275     }
00276 }
00277 
00278 //: Merge the sub-forests pointed to the input vector a single larger one
00279 void merge_sub_forests(const vcl_vector< clsfy_random_forest*>& sub_forests,
00280                        clsfy_random_forest& large_forest)
00281 {
00282     vcl_vector<clsfy_random_forest*>::const_iterator subForestIter=sub_forests.begin();
00283     vcl_vector<clsfy_random_forest*>::const_iterator subForestIterEnd=sub_forests.end();
00284     while (subForestIter != subForestIterEnd)
00285     {
00286         const clsfy_random_forest& subForest=**subForestIter;
00287         large_forest.trees_.reserve(large_forest.trees_.size()+subForest.trees_.size());
00288         large_forest.trees_.insert(large_forest.trees_.end(),
00289                                    subForest.trees_.begin(),subForest.trees_.end());
00290         ++subForestIter;
00291     }
00292 }
00293 
00294 //: Merge the two input forests
00295 clsfy_random_forest operator+(const clsfy_random_forest& forest1,
00296                               const clsfy_random_forest& forest2)
00297 {
00298     clsfy_random_forest mergedForest=forest1;
00299 
00300     mergedForest.trees_.reserve(forest1.trees_.size()+forest2.trees_.size());
00301     mergedForest.trees_.insert(mergedForest.trees_.end(),
00302                                forest2.trees_.begin(),forest2.trees_.end());
00303     return mergedForest;
00304 }
00305