00001
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005
00006
00007
00008
00009
00010 #include "clsfy_binary_tree_builder.h"
00011 #include <clsfy/clsfy_binary_threshold_1d_gini_builder.h>
00012
00013 #include <vcl_iostream.h>
00014 #include <vcl_string.h>
00015 #include <vcl_algorithm.h>
00016 #include <vcl_numeric.h>
00017 #include <vcl_iterator.h>
00018 #include <vcl_cassert.h>
00019 #include <vcl_cstddef.h>
00020 #include <vsl/vsl_binary_loader.h>
00021 #include <mbl/mbl_stl.h>
00022 #include <clsfy/clsfy_k_nearest_neighbour.h>
00023
00024
00025
00026 clsfy_binary_tree_builder::clsfy_binary_tree_builder():
00027 max_depth_(-1),min_node_size_(5),nbranch_params_(-1),calc_test_error_(true)
00028 {
00029 unsigned long default_seed=123654987;
00030 seed_sampler(default_seed);
00031 }
00032
00033
00034
00035
00036 short clsfy_binary_tree_builder::version_no() const
00037 {
00038 return 1;
00039 }
00040
00041
00042
00043 vcl_string clsfy_binary_tree_builder::is_a() const
00044 {
00045 return vcl_string("clsfy_binary_tree_builder");
00046 }
00047
00048
00049
00050 bool clsfy_binary_tree_builder::is_class(vcl_string const& s) const
00051 {
00052 return s == clsfy_binary_tree_builder::is_a() || clsfy_builder_base::is_class(s);
00053 }
00054
00055
00056
00057 clsfy_builder_base* clsfy_binary_tree_builder::clone() const
00058 {
00059 return new clsfy_binary_tree_builder(*this);
00060 }
00061
00062
00063
00064 void clsfy_binary_tree_builder::print_summary(vcl_ostream& os) const
00065 {
00066 os << "max_depth = " << max_depth_;
00067 }
00068
00069
00070
00071 void clsfy_binary_tree_builder::b_write(vsl_b_ostream& bfs) const
00072 {
00073 vsl_b_write(bfs, version_no());
00074 vsl_b_write(bfs, max_depth_);
00075 vsl_b_write(bfs, min_node_size_);
00076 vsl_b_write(bfs, nbranch_params_);
00077 vsl_b_write(bfs,calc_test_error_);
00078 vcl_cerr << "clsfy_binary_tree_builder::b_write() NYI\n";
00079 }
00080
00081
00082
00083 void clsfy_binary_tree_builder::b_read(vsl_b_istream& bfs)
00084 {
00085 if (!bfs) return;
00086
00087 short version;
00088 vsl_b_read(bfs,version);
00089 switch (version)
00090 {
00091 case (1):
00092 vsl_b_read(bfs, max_depth_);
00093 vsl_b_read(bfs, min_node_size_);
00094 vsl_b_read(bfs, nbranch_params_);
00095 vsl_b_read(bfs,calc_test_error_);
00096
00097 break;
00098 default:
00099 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_binary_tree_builder&)\n"
00100 << " Unknown version number "<< version << '\n';
00101 bfs.is().clear(vcl_ios::badbit);
00102 }
00103 }
00104
00105
00106
00107
00108
00109
00110
00111 double clsfy_binary_tree_builder::build(clsfy_classifier_base& classifier,
00112 mbl_data_wrapper<vnl_vector<double> >& inputs,
00113 unsigned nClasses,
00114 const vcl_vector<unsigned> &outputs) const
00115 {
00116 assert(classifier.is_class("clsfy_binary_tree"));
00117 assert(inputs.size()==outputs.size());
00118 assert(nClasses=1);
00119
00120
00121 clsfy_binary_tree &binary_tree = static_cast<clsfy_binary_tree&>(classifier);
00122 unsigned npoints=inputs.size();
00123 vcl_vector<vnl_vector<double> > vin(npoints);
00124
00125 inputs.reset();
00126 unsigned i=0;
00127 do
00128 {
00129 vin[i++] = inputs.current();
00130 } while (inputs.next());
00131
00132 assert(i==inputs.size());
00133
00134
00135 unsigned ndims=vin.front().size();
00136 base_indices_.resize(ndims);
00137 mbl_stl_increments(base_indices_.begin(),base_indices_.end(),0);
00138
00139 clsfy_binary_tree_op rootOp;
00140 clsfy_binary_tree_bnode* root=new clsfy_binary_tree_bnode(0,rootOp);
00141
00142
00143 vcl_set<unsigned> indices;
00144 mbl_stl_increments_n(vcl_inserter(indices,indices.end()),npoints,0);
00145
00146 build_a_node(vin,outputs,indices,root);
00147
00148 bool left=true;
00149 bool right=false;
00150
00151 bool pure=isNodePure(root->subIndicesL,outputs);
00152 if (!pure)
00153 {
00154 #if 0
00155 vcl_cout<<"Building the root left branch children"<<vcl_endl;
00156 #endif
00157 build_children(vin,outputs,root,left);
00158 }
00159 else
00160 {
00161 #if 0
00162 vcl_cout<<"Terminating the root left branch"<<vcl_endl;
00163 #endif
00164 add_terminator(vin,outputs,root,left,true);
00165 }
00166 pure=isNodePure(root->subIndicesR,outputs);
00167 if (!pure)
00168 {
00169 #if 0
00170 vcl_cout<<"Building the root right branch children"<<vcl_endl;
00171 #endif
00172 build_children(vin,outputs,root,right);
00173 }
00174 else
00175 {
00176 #if 0
00177 vcl_cout<<"Terminating the root right branch"<<vcl_endl;
00178 #endif
00179 add_terminator(vin,outputs,root,right,true);
00180 }
00181
00182
00183
00184 clsfy_binary_tree_node* classRoot=new clsfy_binary_tree_node(0,root->op_);
00185 set_node_prob(classRoot,root);
00186
00187 copy_children(root,classRoot);
00188 binary_tree.set_root(classRoot);
00189
00190 clsfy_binary_tree::remove_tree(root);
00191
00192 if (calc_test_error_)
00193 return clsfy_test_error(classifier, inputs, outputs);
00194 else
00195 return 0.0;
00196 }
00197
00198 void clsfy_binary_tree_builder::build_children(
00199 const vcl_vector<vnl_vector<double> >& vin,
00200 const vcl_vector<unsigned>& outputs,
00201 clsfy_binary_tree_bnode* parent,bool left) const
00202 {
00203 if (max_depth_>0)
00204 {
00205
00206 int depth=1;
00207 clsfy_binary_tree_bnode* pNode=parent;
00208 while (pNode)
00209 {
00210 pNode=static_cast<clsfy_binary_tree_bnode*>(pNode->parent_);
00211 ++depth;
00212 }
00213 if (depth>=max_depth_)
00214 {
00215
00216 vcl_set<unsigned >& subIndices=(left ? parent->subIndicesL : parent->subIndicesR);
00217 bool pure=isNodePure(subIndices,outputs);
00218 add_terminator(vin,outputs,parent,left,pure);
00219 return;
00220 }
00221 }
00222
00223
00224 vcl_set<unsigned >& subIndices=(left ? parent->subIndicesL : parent->subIndicesR);
00225 clsfy_binary_tree_op dummyOp;
00226 parent->add_child(dummyOp,left);
00227
00228
00229 clsfy_binary_tree_bnode* pChild=dynamic_cast< clsfy_binary_tree_bnode*>(left ? parent->left_child_ : parent->right_child_);
00230 build_a_node(vin,outputs,subIndices,pChild);
00231
00232
00233 if (pChild->subIndicesL.empty() || pChild->subIndicesR.empty())
00234 {
00235
00236
00237 delete pChild;
00238 if (left)
00239 parent->left_child_=0;
00240 else
00241 parent->right_child_=0;
00242
00243 vcl_set<unsigned >& subIndices=(left ? parent->subIndicesL : parent->subIndicesR);
00244 bool pure=isNodePure(subIndices,outputs);
00245 add_terminator(vin,outputs,parent,left,pure);
00246 return;
00247 }
00248
00249
00250 if (min_node_size_>0)
00251 {
00252 if (pChild->subIndicesL.size() < static_cast<vcl_size_t>(min_node_size_) ||
00253 pChild->subIndicesR.size() < static_cast<vcl_size_t>(min_node_size_) )
00254 {
00255
00256
00257 delete pChild;
00258 if (left)
00259 parent->left_child_=0;
00260 else
00261 parent->right_child_=0;
00262
00263 vcl_set<unsigned >& subIndices=(left ? parent->subIndicesL : parent->subIndicesR);
00264 bool pure=isNodePure(subIndices,outputs);
00265 add_terminator(vin,outputs,parent,left,pure);
00266 return;
00267 }
00268 }
00269
00270
00271 clsfy_binary_tree_bnode* pNode=pChild;
00272
00273
00274
00275 bool pure=isNodePure(pNode->subIndicesL,outputs);
00276 bool myLeft=true;
00277 bool myRight=false;
00278 if (!pure)
00279 {
00280 build_children(vin,outputs,pNode,myLeft);
00281 }
00282 else
00283 {
00284 add_terminator(vin,outputs,pNode,myLeft,true);
00285 }
00286
00287
00288
00289 pure=isNodePure(pNode->subIndicesR,outputs);
00290 if (!pure)
00291 {
00292 build_children(vin,outputs,pNode,myRight);
00293 }
00294 else
00295 {
00296 add_terminator(vin,outputs,pNode,myRight,true);
00297 }
00298 }
00299
00300
00301 void clsfy_binary_tree_builder::copy_children(clsfy_binary_tree_bnode* pBuilderNode,
00302 clsfy_binary_tree_node* pNode) const
00303 {
00304 bool left=true;
00305 set_node_prob(pNode,pBuilderNode);
00306
00307 if (pBuilderNode->left_child_)
00308 {
00309 pNode->add_child(pBuilderNode->left_child_->op_,left);
00310 copy_children(static_cast<clsfy_binary_tree_bnode*>(pBuilderNode->left_child_),
00311 pNode->left_child_);
00312 }
00313 if (pBuilderNode->right_child_)
00314 {
00315 pNode->add_child(pBuilderNode->right_child_->op_,!left);
00316 copy_children(static_cast<clsfy_binary_tree_bnode*>(pBuilderNode->right_child_),
00317 pNode->right_child_);
00318 }
00319 }
00320
00321 void clsfy_binary_tree_builder::build_a_node(
00322 const vcl_vector<vnl_vector<double> >& vin,
00323 const vcl_vector<unsigned>& outputs,
00324 const vcl_set<unsigned >& subIndices,
00325 clsfy_binary_tree_bnode* pNode) const
00326 {
00327 clsfy_binary_threshold_1d_gini_builder tbuilder;
00328 unsigned ndims=vin.front().size();
00329 unsigned ndimsUsed=ndims;
00330 vcl_vector<unsigned > param_indices;
00331 if (nbranch_params_>0)
00332 {
00333 ndimsUsed=nbranch_params_;
00334 ndimsUsed=vcl_min(ndimsUsed,ndims);
00335 if (ndimsUsed<ndims)
00336 {
00337
00338
00339
00340 randomise_parameters(ndims,param_indices);
00341 }
00342 else
00343 {
00344 ndimsUsed=ndims;
00345 param_indices.resize(ndims);
00346 mbl_stl_increments(param_indices.begin(),param_indices.end(),0);
00347 }
00348 }
00349 else
00350 {
00351 ndimsUsed=ndims;
00352 param_indices.resize(ndims);
00353 mbl_stl_increments(param_indices.begin(),param_indices.end(),0);
00354 }
00355 vcl_vector<clsfy_classifier_1d*> pBranchClassifiers(ndims,0);
00356 vnl_vector<double > wts(subIndices.size());
00357 wts.fill(1.0/double (vin.size())-1.0E-12);
00358 unsigned npoints=subIndices.size();
00359 vcl_vector<unsigned > subOutputs;
00360 subOutputs.reserve(npoints);
00361 vcl_transform(subIndices.begin(),subIndices.end(),
00362 vcl_back_inserter(subOutputs),
00363 mbl_stl_index_functor<unsigned >(outputs));
00364
00365 double minError=1.0E30;
00366 unsigned ibest=0;
00367
00368 vnl_vector<double > data(npoints);
00369
00370
00371
00372
00373
00374 unsigned npasses=1;
00375 if (ndimsUsed<ndims) npasses=2;
00376 for (unsigned ipass=0;ipass<npasses;++ipass)
00377 {
00378 unsigned istart=0;
00379 unsigned nmax=ndimsUsed;
00380 if (ipass>0)
00381 {
00382 istart=ndimsUsed;
00383 nmax=ndims;
00384 }
00385 for (unsigned idim=istart;idim<nmax;++idim)
00386 {
00387 pBranchClassifiers[idim] = tbuilder.new_classifier();
00388 vcl_set<unsigned >::const_iterator indIter=subIndices.begin();
00389 vcl_set<unsigned >::const_iterator indIterEnd=subIndices.end();
00390 unsigned ipt=0;
00391 while (indIter != indIterEnd)
00392 {
00393 data[ipt] = vin[*indIter][param_indices[idim]];
00394 ++ipt;
00395 ++indIter;
00396 }
00397
00398 double error=tbuilder.build_gini(*pBranchClassifiers[idim],
00399 data,subOutputs);
00400 if (error<minError)
00401 {
00402 minError=error;
00403 ibest=idim;
00404 }
00405 }
00406
00407 pNode->subIndicesL.clear();
00408 pNode->subIndicesR.clear();
00409 clsfy_binary_tree_op op(0,param_indices[ibest]);
00410 op.classifier() = *(static_cast<clsfy_binary_threshold_1d*>(pBranchClassifiers[ibest]));
00411
00412 pNode->op_=op;
00413
00414
00415 vcl_set<unsigned >::const_iterator indIter=subIndices.begin();
00416 vcl_set<unsigned >::const_iterator indIterEnd=subIndices.end();
00417 vcl_set<unsigned >& subIndicesL=pNode->subIndicesL;
00418 vcl_set<unsigned >& subIndicesR=pNode->subIndicesR;
00419 while (indIter != indIterEnd)
00420 {
00421 double x = vin[*indIter][param_indices[ibest]];
00422 if (pBranchClassifiers[ibest]->classify(x)==0)
00423 subIndicesL.insert(*indIter);
00424 else
00425 subIndicesR.insert(*indIter);
00426 ++indIter;
00427 }
00428
00429 if (!subIndicesL.empty() && !subIndicesR.empty())
00430 break;
00431 }
00432 mbl_stl_clean(pBranchClassifiers.begin(),pBranchClassifiers.end());
00433 }
00434
00435 bool clsfy_binary_tree_builder::isNodePure(const vcl_set<unsigned >& subIndices,
00436 const vcl_vector<unsigned>& outputs) const
00437 {
00438 if (subIndices.empty()) return true;
00439 vcl_set<unsigned >::const_iterator indIter=subIndices.begin();
00440 vcl_set<unsigned >::const_iterator indIterEnd=subIndices.end();
00441
00442 unsigned class0=outputs[*indIter];
00443 while (indIter != indIterEnd)
00444 {
00445 if (outputs[*indIter] != class0)
00446 return false;
00447 ++indIter;
00448 }
00449 return true;
00450 }
00451
00452
00453
00454 void clsfy_binary_tree_builder::add_terminator(
00455 const vcl_vector<vnl_vector<double> >& vin,
00456 const vcl_vector<unsigned>& outputs,
00457 clsfy_binary_tree_bnode* parent,
00458 bool left, bool pure) const
00459 {
00460 double thresholdBig=1.0E30;
00461
00462 int dummyIndex=0;
00463 clsfy_binary_tree_op dummyOp(0,dummyIndex);
00464
00465
00466 unsigned classification=0;
00467 double prob=0.5;
00468 if (pure)
00469 {
00470 if (left)
00471 {
00472 if (!(parent->subIndicesL.empty()))
00473 classification=outputs[*(parent->subIndicesL.begin())];
00474 }
00475 else
00476 {
00477 classification=1;
00478 if (!(parent->subIndicesR.empty()))
00479 classification=outputs[*(parent->subIndicesR.begin())];
00480 }
00481 prob = (classification==1 ? 1.0 : 0.0);
00482 }
00483 else
00484 {
00485 vcl_set<unsigned >& indices=(left ? parent->subIndicesL : parent->subIndicesR);
00486 vcl_set<unsigned >::iterator indexIter=indices.begin();
00487 vcl_set<unsigned >::iterator indexIterEnd=indices.end();
00488 unsigned n1=0;
00489 while (indexIter != indexIterEnd)
00490 {
00491 if (outputs[*indexIter]==1)
00492 ++n1;
00493 ++indexIter;
00494 }
00495 prob=double (n1)/double (indices.size());
00496 classification = (prob>0.5 ? 1 : 0);
00497 }
00498 double parity=(classification==0 ? 1.0 : -1.0);
00499 dummyOp.classifier().set(1.0,parity*thresholdBig);
00500
00501 parent->add_child(dummyOp,left);
00502
00503 if (left)
00504 parent->left_child_->prob_=prob;
00505 else
00506 parent->right_child_->prob_=prob;
00507 }
00508
00509
00510
00511
00512 clsfy_classifier_base* clsfy_binary_tree_builder::new_classifier() const
00513 {
00514 return new clsfy_binary_tree();
00515 }
00516
00517 void clsfy_binary_tree_builder::randomise_parameters(unsigned ndimsUsed,
00518 vcl_vector<unsigned >& param_indices) const
00519 {
00520
00521 param_indices.resize(base_indices_.size());
00522
00523 vcl_random_shuffle(base_indices_.begin(),base_indices_.end(),random_sampler_);
00524 vcl_copy(base_indices_.begin(),base_indices_.end(),
00525 param_indices.begin());
00526 }
00527
00528
00529 void clsfy_binary_tree_builder::seed_sampler(unsigned long seed)
00530 {
00531 random_sampler_.reseed(seed);
00532 }
00533
00534 void clsfy_binary_tree_builder::set_node_prob(clsfy_binary_tree_node* pNode,
00535 clsfy_binary_tree_bnode* pBuilderNode) const
00536 {
00537 pNode->prob_ = pBuilderNode->prob_;
00538 }
00539
00540
00541 clsfy_binary_tree_node* clsfy_binary_tree_bnode::create_child(const clsfy_binary_tree_op& op)
00542 {
00543 return new clsfy_binary_tree_bnode(this,op);
00544 }
00545
00546 clsfy_binary_tree_bnode::~clsfy_binary_tree_bnode()
00547 {
00548 }