Go to the documentation of this file.00001
00002 #include "clsfy_random_forest.h"
00003
00004
00005
00006
00007
00008 #include <vcl_string.h>
00009 #include <vcl_deque.h>
00010 #include <vcl_algorithm.h>
00011 #include <vcl_iterator.h>
00012 #include <vcl_cmath.h>
00013 #include <vcl_cassert.h>
00014 #include <vsl/vsl_binary_io.h>
00015 #include <vsl/vsl_vector_io.h>
00016 #include <vnl/io/vnl_io_vector.h>
00017 #include <mbl/mbl_cloneable_ptr.h>
00018
00019
00020 clsfy_random_forest::clsfy_random_forest()
00021 {
00022 }
00023
00024
00025
00026 unsigned clsfy_random_forest::classify(const vnl_vector<double> &input) const
00027 {
00028 #if 1 //Accumulate probabilities (impure final nodes may not return 0 or 1)
00029 vcl_vector<double > classProbs(1,0.0);
00030 class_probabilities(classProbs,input);
00031 return (classProbs[0]>=0.5) ? 1 : 0;
00032 #else // just accumulate number in each class rather than probs
00033
00034 vcl_vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIter=trees_.begin();
00035 vcl_vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIterEnd=trees_.end();
00036
00037 vcl_vector<unsigned > classCount(2,0);
00038
00039 unsigned i=0;
00040 while (treeIter != treeIterEnd)
00041 {
00042 mbl_cloneable_ptr<clsfy_classifier_base> pTree=*treeIter++;
00043 unsigned treeClass= pTree->classify(input);
00044
00045 ++classCount[treeClass];
00046 }
00047 if (classCount[0] >= classCount[1])
00048 return 0;
00049 else
00050 return 1;
00051 #endif // 1
00052 }
00053
00054
00055
00056
00057 void clsfy_random_forest::class_probabilities(vcl_vector<double>& outputs,
00058 vnl_vector<double>const& input) const
00059 {
00060 outputs.resize(1);
00061
00062 vcl_vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIter=trees_.begin();
00063 vcl_vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIterEnd=trees_.end();
00064
00065 vcl_vector<double > classProbs(1,0.0);
00066 vcl_vector<double > meanProbs(1,0.0);
00067
00068 while (treeIter != treeIterEnd)
00069 {
00070 const clsfy_classifier_base* pTree=(*treeIter).ptr();
00071 pTree->class_probabilities(classProbs, input);
00072 meanProbs[0]+=classProbs[0];
00073 ++treeIter;
00074 }
00075 outputs[0]=meanProbs[0]/double (trees_.size());
00076 }
00077
00078
00079
00080
00081
00082 double clsfy_random_forest::log_l(const vnl_vector<double> &input) const
00083 {
00084
00085
00086 double epsilon=1.0E-8;
00087 vcl_vector<double > probs(1,0.5);
00088 class_probabilities(probs,input);
00089 double p=probs[0];
00090 double d=(1.0/p)-1.0;
00091 double x=1.0;
00092 if (d>epsilon)
00093 x=-vcl_log(d);
00094 else
00095 x=-vcl_log(epsilon);
00096
00097 return x;
00098 }
00099
00100
00101 void clsfy_random_forest::class_probabilities_oob(vcl_vector<double> &outputs,
00102 const vnl_vector<double> &input,
00103 const vcl_vector<vcl_vector<unsigned > >& oobIndices,
00104 unsigned this_index) const
00105 {
00106 outputs.resize(1);
00107
00108 vcl_vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIter=trees_.begin();
00109 vcl_vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIterEnd=trees_.end();
00110
00111 vcl_vector<double > classProbs(1,0.0);
00112 vcl_vector<double > meanProbs(1,0.0);
00113 vcl_vector<vcl_vector<unsigned > >::const_iterator oobIndexIter=oobIndices.begin() ;
00114 unsigned noob=0;
00115 while (treeIter != treeIterEnd)
00116 {
00117 if (vcl_find(oobIndexIter->begin(),oobIndexIter->end(),this_index)==oobIndexIter->end())
00118 {
00119
00120 const clsfy_classifier_base* pTree=(*treeIter).ptr();
00121
00122 pTree->class_probabilities(classProbs, input);
00123 meanProbs[0]+=classProbs[0];
00124 ++noob;
00125 }
00126 ++treeIter;
00127 ++oobIndexIter;
00128 }
00129 outputs[0]=meanProbs[0]/double (noob);
00130 }
00131
00132
00133
00134
00135 unsigned clsfy_random_forest::classify_oob(const vnl_vector<double> &input,
00136 const vcl_vector<vcl_vector<unsigned > >& oobIndices,
00137 unsigned this_index) const
00138 {
00139 vcl_vector<double > classProbs(1,0.0);
00140 class_probabilities_oob(classProbs,input,oobIndices,this_index);
00141 return (classProbs[0]>=0.5) ? 1 : 0;
00142 }
00143
00144
00145
00146
00147 vcl_string clsfy_random_forest::is_a() const
00148 {
00149 return vcl_string("clsfy_random_forest");
00150 }
00151
00152
00153
00154 bool clsfy_random_forest::is_class(vcl_string const& s) const
00155 {
00156 return s == clsfy_random_forest::is_a() || clsfy_classifier_base::is_class(s);
00157 }
00158
00159
00160
00161 short clsfy_random_forest::version_no() const
00162 {
00163 return 1;
00164 }
00165
00166
00167
00168 clsfy_classifier_base* clsfy_random_forest::clone() const
00169 {
00170 return new clsfy_random_forest(*this);
00171 }
00172
00173
00174
00175 void clsfy_random_forest::print_summary(vcl_ostream& os) const
00176 {
00177 os<<"clsfy_random_forest\t has "<<trees_.size()<<" trees"<<vcl_endl;
00178 }
00179
00180
00181
00182 void clsfy_random_forest::b_write(vsl_b_ostream& bfs) const
00183 {
00184 vcl_cout<<"clsfy_random_forest::b_write"<<vcl_endl;
00185 vsl_b_write(bfs,version_no());
00186 unsigned n=trees_.size();
00187 vsl_b_write(bfs,n);
00188 for (unsigned i=0; i<n;++i)
00189 {
00190 trees_[i]->b_write(bfs);
00191 }
00192 }
00193
00194
00195
00196 void clsfy_random_forest::b_read(vsl_b_istream& bfs)
00197 {
00198 if (!bfs) return;
00199
00200 prune();
00201 short version;
00202 vsl_b_read(bfs,version);
00203 switch (version)
00204 {
00205 case 1:
00206 {
00207 unsigned n;
00208 vsl_b_read(bfs,n);
00209 vcl_cout<<"Am attemptig to read in "<<n<<"\t trees"<<vcl_endl;
00210 trees_.reserve(n);
00211 for (unsigned i=0; i<n;++i)
00212 {
00213
00214 mbl_cloneable_ptr< clsfy_classifier_base> tree(new clsfy_binary_tree);
00215 trees_.push_back(tree);
00216 trees_.back()->b_read(bfs);
00217 }
00218 break;
00219 }
00220
00221 default:
00222 vcl_cerr << "I/O ERROR: clsfy_random_forest::b_read(vsl_b_istream&)\n"
00223 << " Unknown version number "<< version << '\n';
00224 bfs.is().clear(vcl_ios::badbit);
00225 }
00226 }
00227
00228 clsfy_random_forest::~clsfy_random_forest()
00229 {
00230 prune();
00231 }
00232
00233 void clsfy_random_forest::prune()
00234 {
00235 trees_.clear();
00236 }
00237
00238
00239
00240 unsigned clsfy_random_forest::n_dims() const
00241 {
00242 if (trees_.empty())
00243 return 0;
00244 else
00245 return trees_.front()->n_dims();
00246 }
00247
00248 clsfy_random_forest& clsfy_random_forest::operator+=(const clsfy_random_forest& forest2)
00249 {
00250 this->trees_.reserve(this->trees_.size()+forest2.trees_.size());
00251 this->trees_.insert(this->trees_.end(),
00252 forest2.trees_.begin(),forest2.trees_.end());
00253 return *this;
00254 }
00255
00256
00257
00258
00259
00260 void merge_sub_forests(const vcl_vector<vcl_string>& filenames,
00261 clsfy_random_forest& large_forest)
00262 {
00263 vcl_vector<vcl_string>::const_iterator fileIter=filenames.begin();
00264 vcl_vector<vcl_string>::const_iterator fileIterEnd=filenames.end();
00265 while (fileIter != fileIterEnd)
00266 {
00267 vsl_b_ifstream bfs_in(*fileIter);
00268 clsfy_random_forest subForest;
00269 vsl_b_read(bfs_in, subForest);
00270 bfs_in.close();
00271 large_forest.trees_.reserve(large_forest.trees_.size()+subForest.trees_.size());
00272 large_forest.trees_.insert(large_forest.trees_.end(),
00273 subForest.trees_.begin(),subForest.trees_.end());
00274 ++fileIter;
00275 }
00276 }
00277
00278
00279 void merge_sub_forests(const vcl_vector< clsfy_random_forest*>& sub_forests,
00280 clsfy_random_forest& large_forest)
00281 {
00282 vcl_vector<clsfy_random_forest*>::const_iterator subForestIter=sub_forests.begin();
00283 vcl_vector<clsfy_random_forest*>::const_iterator subForestIterEnd=sub_forests.end();
00284 while (subForestIter != subForestIterEnd)
00285 {
00286 const clsfy_random_forest& subForest=**subForestIter;
00287 large_forest.trees_.reserve(large_forest.trees_.size()+subForest.trees_.size());
00288 large_forest.trees_.insert(large_forest.trees_.end(),
00289 subForest.trees_.begin(),subForest.trees_.end());
00290 ++subForestIter;
00291 }
00292 }
00293
00294
00295 clsfy_random_forest operator+(const clsfy_random_forest& forest1,
00296 const clsfy_random_forest& forest2)
00297 {
00298 clsfy_random_forest mergedForest=forest1;
00299
00300 mergedForest.trees_.reserve(forest1.trees_.size()+forest2.trees_.size());
00301 mergedForest.trees_.insert(mergedForest.trees_.end(),
00302 forest2.trees_.begin(),forest2.trees_.end());
00303 return mergedForest;
00304 }
00305