Go to the documentation of this file.00001
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005
00006
00007
00008
00009
#include "clsfy_random_forest_builder.h"
#include <vxl_config.h>
#include <vcl_iostream.h>
#include <vcl_string.h>
#include <vcl_algorithm.h>
#include <vcl_numeric.h>
#include <vcl_iterator.h>
#include <vcl_cassert.h>
#include <vcl_cmath.h>
#include <vsl/vsl_binary_loader.h>
#include <mbl/mbl_stl.h>
#include <mbl/mbl_data_array_wrapper.h>
#include <clsfy/clsfy_binary_tree_builder.h>
#include "clsfy_random_forest.h"
00023
00024
00025
00026 clsfy_random_forest_builder::clsfy_random_forest_builder()
00027 : ntrees_(100),
00028 max_depth_(-1), min_node_size_(-1),
00029 poob_indices_(0),
00030 calc_test_error_(true)
00031 {
00032 unsigned long default_seed=123654987;
00033 seed_sampler(default_seed);
00034 }
00035
00036 clsfy_random_forest_builder::clsfy_random_forest_builder(unsigned ntrees,
00037 int max_depth,
00038 int min_node_size)
00039 : ntrees_(ntrees),
00040 max_depth_(max_depth), min_node_size_(min_node_size),
00041 poob_indices_(0),
00042 calc_test_error_(true)
00043 {
00044 unsigned long default_seed=123654987;
00045 seed_sampler(default_seed);
00046 }
00047
00048 clsfy_random_forest_builder::~clsfy_random_forest_builder()
00049 {
00050 }
00051
00052
00053 short clsfy_random_forest_builder::version_no() const
00054 {
00055 return 1;
00056 }
00057
00058
00059
00060 vcl_string clsfy_random_forest_builder::is_a() const
00061 {
00062 return vcl_string("clsfy_random_forest_builder");
00063 }
00064
00065
00066
00067 bool clsfy_random_forest_builder::is_class(vcl_string const& s) const
00068 {
00069 return s == clsfy_random_forest_builder::is_a() || clsfy_builder_base::is_class(s);
00070 }
00071
00072
00073
00074 clsfy_builder_base* clsfy_random_forest_builder::clone() const
00075 {
00076 return new clsfy_random_forest_builder(*this);
00077 }
00078
00079
00080
00081 void clsfy_random_forest_builder::print_summary(vcl_ostream& os) const
00082 {
00083 os << "Num trees = "<<ntrees_<<"\tmax_depth = " << max_depth_;
00084 }
00085
00086
00087
00088 void clsfy_random_forest_builder::b_write(vsl_b_ostream& bfs) const
00089 {
00090 vsl_b_write(bfs, version_no());
00091 vsl_b_write(bfs, ntrees_);
00092 vsl_b_write(bfs, max_depth_);
00093 vsl_b_write(bfs, min_node_size_);
00094 vsl_b_write(bfs,calc_test_error_);
00095 vcl_cerr << "clsfy_random_forest_builder::b_write() NYI\n";
00096 }
00097
00098
00099
00100 void clsfy_random_forest_builder::b_read(vsl_b_istream& bfs)
00101 {
00102 if (!bfs) return;
00103
00104 short version;
00105 vsl_b_read(bfs,version);
00106 switch (version)
00107 {
00108 case (1):
00109 vsl_b_read(bfs, ntrees_);
00110 vsl_b_read(bfs, max_depth_);
00111 vsl_b_read(bfs, min_node_size_);
00112 vsl_b_read(bfs,calc_test_error_);
00113 break;
00114 default:
00115 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_random_forest_builder&)\n"
00116 << " Unknown version number "<< version << '\n';
00117 bfs.is().clear(vcl_ios::badbit);
00118 }
00119 }
00120
00121
00122
00123
00124
00125
00126
00127 double clsfy_random_forest_builder::build(clsfy_classifier_base& classifier,
00128 mbl_data_wrapper<vnl_vector<double> >& inputs,
00129 unsigned nClasses,
00130 const vcl_vector<unsigned> &outputs) const
00131 {
00132 assert(classifier.is_class("clsfy_random_forest"));
00133 assert(inputs.size()==outputs.size());
00134 assert(nClasses=1);
00135
00136
00137 clsfy_random_forest &random_forest = static_cast<clsfy_random_forest&>(classifier);
00138 unsigned npoints=inputs.size();
00139 vcl_vector<vnl_vector<double> > vin(npoints);
00140
00141 inputs.reset();
00142 unsigned i=0;
00143 do
00144 {
00145 vin[i++] = inputs.current();
00146 } while (inputs.next());
00147
00148 assert(i==inputs.size());
00149
00150 unsigned ndims=vin[0].size();
00151 int nbranch_params=select_nbranch_params(ndims);
00152
00153
00154 vcl_cout<<"npoints= "<<npoints<<"\tndims= "<<ndims<<vcl_endl;
00155 vcl_vector<unsigned> indices(ndims,0);
00156
00157 mbl_stl_increments(indices.begin(),indices.end(),0);
00158
00159
00160 random_forest.prune();
00161
00162 if (poob_indices_)
00163 {
00164 poob_indices_->clear();
00165 poob_indices_->reserve(ntrees_);
00166 }
00167
00168
00169 vcl_vector<vnl_vector<double> > bootstrapped_inputs;
00170 vcl_vector<unsigned > bootstrapped_outputs;
00171
00172 for (i=0;i<ntrees_;++i)
00173 {
00174 select_data(vin,outputs,bootstrapped_inputs,bootstrapped_outputs);
00175
00176 clsfy_binary_tree_builder builder;
00177 builder.set_calc_test_error(false);
00178
00179 clsfy_classifier_base* pBaseClassifier=builder.new_classifier();
00180 clsfy_binary_tree* pTreeClassifier=dynamic_cast<clsfy_binary_tree*>(pBaseClassifier);
00181 assert(pTreeClassifier);
00182 builder.set_nbranch_params(nbranch_params);
00183
00184 unsigned long seed=get_tree_builder_seed();
00185
00186 builder.seed_sampler(seed);
00187
00188 builder.set_max_depth(max_depth_);
00189 builder.set_min_node_size(min_node_size_);
00190 mbl_data_array_wrapper<vnl_vector<double> > bootstrapped_inputs_mbl(bootstrapped_inputs);
00191
00192 builder.build(*pTreeClassifier,
00193 bootstrapped_inputs_mbl,
00194 1,
00195 bootstrapped_outputs);
00196
00197 mbl_cloneable_ptr<clsfy_classifier_base> treeClassifier(pTreeClassifier);
00198 random_forest.trees_.push_back(treeClassifier);
00199 }
00200
00201 if (calc_test_error_)
00202 return clsfy_test_error(classifier, inputs, outputs);
00203 else
00204 return 0.0;
00205 }
00206
00207
00208
00209 clsfy_classifier_base* clsfy_random_forest_builder::new_classifier() const
00210 {
00211 return new clsfy_random_forest();
00212 }
00213
00214
00215 void clsfy_random_forest_builder::select_data(vcl_vector<vnl_vector<double> >& inputs,
00216 const vcl_vector<unsigned> &outputs,
00217 vcl_vector<vnl_vector<double> >& bootstrapped_inputs,
00218 vcl_vector<unsigned> & bootstrapped_outputs) const
00219 {
00220 unsigned npoints=inputs.size();
00221 bootstrapped_inputs.resize(npoints);
00222 bootstrapped_outputs.resize(npoints);
00223 unsigned ndims= inputs.front().size();
00224 if (poob_indices_)
00225 {
00226 poob_indices_->push_back(vcl_vector<unsigned>());
00227 poob_indices_->back().reserve(npoints);
00228 }
00229 for (unsigned i=0;i<npoints;++i)
00230 {
00231 bootstrapped_inputs[i].set_size(ndims);
00232 unsigned index=random_sampler_(npoints);
00233 bootstrapped_inputs[i]=inputs[index];
00234 bootstrapped_outputs[i]=outputs[index];
00235 if (poob_indices_)
00236 poob_indices_->back().push_back(index);
00237 }
00238 }
00239
00240 unsigned clsfy_random_forest_builder::select_nbranch_params(unsigned ndims) const
00241 {
00242 unsigned nbranch_params=1;
00243 if (ndims>2)
00244 {
00245 double dnbranch_params=vcl_sqrt(double(ndims)+0.1);
00246 nbranch_params=unsigned (dnbranch_params);
00247 }
00248 return nbranch_params;
00249 }
00250
00251 void clsfy_random_forest_builder::seed_sampler(unsigned long seed)
00252 {
00253 random_sampler_.reseed(seed);
00254 }
00255
00256 unsigned long clsfy_random_forest_builder::get_tree_builder_seed() const
00257 {
00258
00259 unsigned long N=256;
00260 unsigned nbytes=sizeof(unsigned long);
00261 vcl_vector<vxl_byte> seedAsBytes(nbytes,1);
00262
00263 for (unsigned ib=0;ib<nbytes;++ib)
00264 {
00265 seedAsBytes[ib]=static_cast<vxl_byte>(random_sampler_(N));
00266 }
00267
00268 unsigned long* pSeed=reinterpret_cast<unsigned long*>(&seedAsBytes[0]);
00269 return *pSeed;
00270 }