contrib/mul/clsfy/clsfy_random_forest_builder.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_random_forest_builder.cxx
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005 //:
00006 // \file
00007 // \brief Implement a random_forest classifier builder
00008 // \author Martin Roberts
00009 
#include "clsfy_random_forest_builder.h"
#include <vxl_config.h>
#include <vcl_iostream.h>
#include <vcl_string.h>
#include <vcl_algorithm.h>
#include <vcl_numeric.h>
#include <vcl_iterator.h>
#include <vcl_cassert.h>
#include <vcl_cmath.h>
#include <vsl/vsl_binary_loader.h>
#include <mbl/mbl_stl.h>
#include <mbl/mbl_data_array_wrapper.h>
#include <clsfy/clsfy_binary_tree_builder.h>
#include "clsfy_random_forest.h"
00023 
00024 //=======================================================================
00025 
00026 clsfy_random_forest_builder::clsfy_random_forest_builder()
00027   : ntrees_(100),
00028     max_depth_(-1), min_node_size_(-1),
00029     poob_indices_(0),
00030     calc_test_error_(true)
00031 {
00032     unsigned long default_seed=123654987;
00033     seed_sampler(default_seed);
00034 }
00035 
00036 clsfy_random_forest_builder::clsfy_random_forest_builder(unsigned ntrees,
00037                                                          int max_depth,
00038                                                          int min_node_size)
00039   : ntrees_(ntrees),
00040     max_depth_(max_depth), min_node_size_(min_node_size),
00041     poob_indices_(0),
00042     calc_test_error_(true)
00043 {
00044     unsigned long default_seed=123654987;
00045     seed_sampler(default_seed);
00046 }
00047 
00048 clsfy_random_forest_builder::~clsfy_random_forest_builder()
00049 {
00050 }
00051 //=======================================================================
00052 
00053 short clsfy_random_forest_builder::version_no() const
00054 {
00055     return 1;
00056 }
00057 
00058 //=======================================================================
00059 
00060 vcl_string clsfy_random_forest_builder::is_a() const
00061 {
00062     return vcl_string("clsfy_random_forest_builder");
00063 }
00064 
00065 //=======================================================================
00066 
00067 bool clsfy_random_forest_builder::is_class(vcl_string const& s) const
00068 {
00069     return s == clsfy_random_forest_builder::is_a() || clsfy_builder_base::is_class(s);
00070 }
00071 
00072 //=======================================================================
00073 
00074 clsfy_builder_base* clsfy_random_forest_builder::clone() const
00075 {
00076     return new clsfy_random_forest_builder(*this);
00077 }
00078 
00079 //=======================================================================
00080 
00081 void clsfy_random_forest_builder::print_summary(vcl_ostream& os) const
00082 {
00083     os << "Num trees = "<<ntrees_<<"\tmax_depth = " << max_depth_;
00084 }
00085 
00086 //=======================================================================
00087 
00088 void clsfy_random_forest_builder::b_write(vsl_b_ostream& bfs) const
00089 {
00090     vsl_b_write(bfs, version_no());
00091     vsl_b_write(bfs, ntrees_);
00092     vsl_b_write(bfs, max_depth_);
00093     vsl_b_write(bfs, min_node_size_);
00094     vsl_b_write(bfs,calc_test_error_);
00095     vcl_cerr << "clsfy_random_forest_builder::b_write() NYI\n";
00096 }
00097 
00098 //=======================================================================
00099 
00100 void clsfy_random_forest_builder::b_read(vsl_b_istream& bfs)
00101 {
00102     if (!bfs) return;
00103 
00104     short version;
00105     vsl_b_read(bfs,version);
00106     switch (version)
00107     {
00108         case (1):
00109             vsl_b_read(bfs, ntrees_);
00110             vsl_b_read(bfs, max_depth_);
00111             vsl_b_read(bfs, min_node_size_);
00112             vsl_b_read(bfs,calc_test_error_);
00113             break;
00114         default:
00115             vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_random_forest_builder&)\n"
00116                      << "           Unknown version number "<< version << '\n';
00117             bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00118     }
00119 }
00120 
00121 //=======================================================================
00122 
00123 //: Build model from data
00124 // return the mean error over the training set.
00125 // For many classifiers, you may use nClasses==1 to
00126 // indicate a binary classifier
00127 double clsfy_random_forest_builder::build(clsfy_classifier_base& classifier,
00128                                           mbl_data_wrapper<vnl_vector<double> >& inputs,
00129                                           unsigned nClasses,
00130                                           const vcl_vector<unsigned> &outputs) const
00131 {
00132     assert(classifier.is_class("clsfy_random_forest")); // equiv to dynamic_cast<> != 0
00133     assert(inputs.size()==outputs.size());
00134     assert(nClasses=1);
00135 
00136 
00137     clsfy_random_forest &random_forest = static_cast<clsfy_random_forest&>(classifier);
00138     unsigned npoints=inputs.size();
00139     vcl_vector<vnl_vector<double> > vin(npoints);
00140 
00141     inputs.reset();
00142     unsigned i=0;
00143     do
00144     {
00145         vin[i++] = inputs.current();
00146     } while (inputs.next());
00147 
00148     assert(i==inputs.size());
00149 
00150     unsigned ndims=vin[0].size();
00151     int nbranch_params=select_nbranch_params(ndims);
00152 
00153     //Start with all parameter indices
00154     vcl_cout<<"npoints= "<<npoints<<"\tndims= "<<ndims<<vcl_endl;
00155     vcl_vector<unsigned> indices(ndims,0);
00156 
00157     mbl_stl_increments(indices.begin(),indices.end(),0);
00158 
00159     //Clean any old trees
00160     random_forest.prune();
00161 
00162     if (poob_indices_)
00163     {
00164         poob_indices_->clear();
00165         poob_indices_->reserve(ntrees_);
00166     }
00167 
00168 
00169     vcl_vector<vnl_vector<double> > bootstrapped_inputs;
00170     vcl_vector<unsigned  > bootstrapped_outputs;
00171 
00172     for (i=0;i<ntrees_;++i)
00173     {
00174         select_data(vin,outputs,bootstrapped_inputs,bootstrapped_outputs);
00175 
00176         clsfy_binary_tree_builder builder;
00177         builder.set_calc_test_error(false);
00178 
00179         clsfy_classifier_base* pBaseClassifier=builder.new_classifier();
00180         clsfy_binary_tree* pTreeClassifier=dynamic_cast<clsfy_binary_tree*>(pBaseClassifier);
00181         assert(pTreeClassifier);
00182         builder.set_nbranch_params(nbranch_params);
00183 
00184         unsigned long seed=get_tree_builder_seed();
00185 //        vcl_cout<<"The seed is "<<seed<<vcl_endl;
00186         builder.seed_sampler(seed);
00187 
00188         builder.set_max_depth(max_depth_);
00189         builder.set_min_node_size(min_node_size_);
00190         mbl_data_array_wrapper<vnl_vector<double> > bootstrapped_inputs_mbl(bootstrapped_inputs);
00191 
00192         builder.build(*pTreeClassifier,
00193                       bootstrapped_inputs_mbl,
00194                       1,
00195                       bootstrapped_outputs);
00196 
00197         mbl_cloneable_ptr<clsfy_classifier_base> treeClassifier(pTreeClassifier);
00198         random_forest.trees_.push_back(treeClassifier);
00199     }
00200 
00201     if (calc_test_error_)
00202         return clsfy_test_error(classifier, inputs, outputs);
00203     else
00204         return 0.0;
00205 }
00206 //=======================================================================
00207 //: Create empty classifier
00208 // Caller is responsible for deletion
00209 clsfy_classifier_base* clsfy_random_forest_builder::new_classifier() const
00210 {
00211     return new clsfy_random_forest();
00212 }
00213 
00214 
00215 void clsfy_random_forest_builder::select_data(vcl_vector<vnl_vector<double> >& inputs,
00216                                               const vcl_vector<unsigned> &outputs,
00217                                               vcl_vector<vnl_vector<double> >& bootstrapped_inputs,
00218                                               vcl_vector<unsigned> & bootstrapped_outputs) const
00219 {
00220     unsigned npoints=inputs.size();
00221     bootstrapped_inputs.resize(npoints);
00222     bootstrapped_outputs.resize(npoints);
00223     unsigned ndims=  inputs.front().size();
00224     if (poob_indices_)
00225     {
00226         poob_indices_->push_back(vcl_vector<unsigned>());
00227         poob_indices_->back().reserve(npoints);
00228     }
00229     for (unsigned i=0;i<npoints;++i)
00230     {
00231         bootstrapped_inputs[i].set_size(ndims);
00232         unsigned index=random_sampler_(npoints);
00233         bootstrapped_inputs[i]=inputs[index];
00234         bootstrapped_outputs[i]=outputs[index];
00235         if (poob_indices_)
00236             poob_indices_->back().push_back(index); //store index of point for later OOB estimates
00237     }
00238 }
00239 
00240 unsigned  clsfy_random_forest_builder::select_nbranch_params(unsigned ndims) const
00241 {
00242     unsigned nbranch_params=1;
00243     if (ndims>2)
00244     {
00245         double dnbranch_params=vcl_sqrt(double(ndims)+0.1); //round up if close
00246         nbranch_params=unsigned (dnbranch_params); //round
00247     }
00248     return nbranch_params;
00249 }
00250 
00251 void clsfy_random_forest_builder::seed_sampler(unsigned long seed)
00252 {
00253     random_sampler_.reseed(seed);
00254 }
00255 
00256 unsigned long clsfy_random_forest_builder::get_tree_builder_seed() const
00257 {
00258     //generate some bytes from the original seeded random number generator
00259     unsigned long N=256;
00260     unsigned nbytes=sizeof(unsigned long);
00261     vcl_vector<vxl_byte> seedAsBytes(nbytes,1);
00262 
00263     for (unsigned ib=0;ib<nbytes;++ib)
00264     {
00265         seedAsBytes[ib]=static_cast<vxl_byte>(random_sampler_(N));
00266     }
00267 
00268     unsigned long* pSeed=reinterpret_cast<unsigned long*>(&seedAsBytes[0]);
00269     return *pSeed;
00270 }