contrib/mul/clsfy/clsfy_binary_tree_builder.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_tree_builder.cxx
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005 //:
00006 // \file
00007 // \brief Implement a binary_tree classifier builder
00008 // \author Martin Roberts
00009 
00010 #include "clsfy_binary_tree_builder.h"
00011 #include <clsfy/clsfy_binary_threshold_1d_gini_builder.h>
00012 
00013 #include <vcl_iostream.h>
00014 #include <vcl_string.h>
00015 #include <vcl_algorithm.h>
00016 #include <vcl_numeric.h>
00017 #include <vcl_iterator.h>
00018 #include <vcl_cassert.h>
00019 #include <vcl_cstddef.h>
00020 #include <vsl/vsl_binary_loader.h>
00021 #include <mbl/mbl_stl.h>
00022 #include <clsfy/clsfy_k_nearest_neighbour.h>
00023 
00024 //=======================================================================
00025 
00026 clsfy_binary_tree_builder::clsfy_binary_tree_builder():
00027     max_depth_(-1),min_node_size_(5),nbranch_params_(-1),calc_test_error_(true)
00028 {
00029     unsigned long default_seed=123654987;
00030     seed_sampler(default_seed);
00031 }
00032 
00033 
00034 //=======================================================================
00035 
00036 short clsfy_binary_tree_builder::version_no() const
00037 {
00038     return 1;
00039 }
00040 
00041 //=======================================================================
00042 
00043 vcl_string clsfy_binary_tree_builder::is_a() const
00044 {
00045     return vcl_string("clsfy_binary_tree_builder");
00046 }
00047 
00048 //=======================================================================
00049 
00050 bool clsfy_binary_tree_builder::is_class(vcl_string const& s) const
00051 {
00052     return s == clsfy_binary_tree_builder::is_a() || clsfy_builder_base::is_class(s);
00053 }
00054 
00055 //=======================================================================
00056 
00057 clsfy_builder_base* clsfy_binary_tree_builder::clone() const
00058 {
00059     return new clsfy_binary_tree_builder(*this);
00060 }
00061 
00062 //=======================================================================
00063 
00064 void clsfy_binary_tree_builder::print_summary(vcl_ostream& os) const
00065 {
00066     os << "max_depth = " << max_depth_;
00067 }
00068 
00069 //=======================================================================
00070 
00071 void clsfy_binary_tree_builder::b_write(vsl_b_ostream& bfs) const
00072 {
00073     vsl_b_write(bfs, version_no());
00074     vsl_b_write(bfs, max_depth_);
00075     vsl_b_write(bfs, min_node_size_);
00076     vsl_b_write(bfs, nbranch_params_);
00077     vsl_b_write(bfs,calc_test_error_);
00078     vcl_cerr << "clsfy_binary_tree_builder::b_write() NYI\n";
00079 }
00080 
00081 //=======================================================================
00082 
00083 void clsfy_binary_tree_builder::b_read(vsl_b_istream& bfs)
00084 {
00085     if (!bfs) return;
00086 
00087     short version;
00088     vsl_b_read(bfs,version);
00089     switch (version)
00090     {
00091         case (1):
00092             vsl_b_read(bfs, max_depth_);
00093             vsl_b_read(bfs, min_node_size_);
00094             vsl_b_read(bfs, nbranch_params_);
00095             vsl_b_read(bfs,calc_test_error_);
00096 
00097             break;
00098         default:
00099             vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_binary_tree_builder&)\n"
00100                      << "           Unknown version number "<< version << '\n';
00101             bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00102     }
00103 }
00104 
00105 //=======================================================================
00106 
00107 //: Build model from data
00108 // return the mean error over the training set.
00109 // For many classifiers, you may use nClasses==1 to
00110 // indicate a binary classifier
00111 double clsfy_binary_tree_builder::build(clsfy_classifier_base& classifier,
00112                                         mbl_data_wrapper<vnl_vector<double> >& inputs,
00113                                         unsigned nClasses,
00114                                         const vcl_vector<unsigned> &outputs) const
00115 {
00116     assert(classifier.is_class("clsfy_binary_tree")); // equiv to dynamic_cast<> != 0
00117     assert(inputs.size()==outputs.size());
00118     assert(nClasses=1);
00119 
00120 
00121     clsfy_binary_tree &binary_tree = static_cast<clsfy_binary_tree&>(classifier);
00122     unsigned npoints=inputs.size();
00123     vcl_vector<vnl_vector<double> > vin(npoints);
00124 
00125     inputs.reset();
00126     unsigned i=0;
00127     do
00128     {
00129         vin[i++] = inputs.current();
00130     } while (inputs.next());
00131 
00132     assert(i==inputs.size());
00133 
00134 
00135     unsigned ndims=vin.front().size();
00136     base_indices_.resize(ndims);
00137     mbl_stl_increments(base_indices_.begin(),base_indices_.end(),0);
00138 
00139     clsfy_binary_tree_op rootOp;
00140     clsfy_binary_tree_bnode* root=new clsfy_binary_tree_bnode(0,rootOp);
00141 
00142     // Start with all indices
00143     vcl_set<unsigned> indices;
00144     mbl_stl_increments_n(vcl_inserter(indices,indices.end()),npoints,0);
00145     // Build the root node starting from all indices
00146     build_a_node(vin,outputs,indices,root);
00147 
00148     bool left=true;
00149     bool right=false;
00150 
00151     bool pure=isNodePure(root->subIndicesL,outputs);
00152     if (!pure)
00153     { // Build the left branch children (recursively)
00154 #if 0
00155         vcl_cout<<"Building the root left branch children"<<vcl_endl;
00156 #endif
00157         build_children(vin,outputs,root,left);
00158     }
00159     else
00160     {
00161 #if 0
00162         vcl_cout<<"Terminating the root left branch"<<vcl_endl;
00163 #endif
00164         add_terminator(vin,outputs,root,left,true);
00165     }
00166     pure=isNodePure(root->subIndicesR,outputs);
00167     if (!pure)
00168     {    // Build the right branch children (recursively)
00169 #if 0
00170         vcl_cout<<"Building the root right branch children"<<vcl_endl;
00171 #endif
00172         build_children(vin,outputs,root,right);
00173     }
00174     else
00175     {
00176 #if 0
00177         vcl_cout<<"Terminating the root right branch"<<vcl_endl;
00178 #endif
00179         add_terminator(vin,outputs,root,right,true);
00180     }
00181 
00182 
00183     // Then copy into the classifier (like b_write, b_read)
00184     clsfy_binary_tree_node* classRoot=new clsfy_binary_tree_node(0,root->op_);
00185     set_node_prob(classRoot,root);
00186 
00187     copy_children(root,classRoot);
00188     binary_tree.set_root(classRoot);
00189 
00190     clsfy_binary_tree::remove_tree(root);
00191 
00192     if (calc_test_error_)
00193         return clsfy_test_error(classifier, inputs, outputs);
00194     else
00195         return 0.0;
00196 }
00197 
00198 void clsfy_binary_tree_builder::build_children(
00199     const vcl_vector<vnl_vector<double> >& vin,
00200     const vcl_vector<unsigned>& outputs,
00201     clsfy_binary_tree_bnode* parent,bool left) const
00202 {
00203     if (max_depth_>0)
00204     {
00205         // Validate depth
00206         int depth=1;
00207         clsfy_binary_tree_bnode* pNode=parent;
00208         while (pNode)
00209         {
00210             pNode=static_cast<clsfy_binary_tree_bnode*>(pNode->parent_);
00211             ++depth;
00212         }
00213         if (depth>=max_depth_)
00214         {
00215             // Can't go any deeper on this branch
00216             vcl_set<unsigned >& subIndices=(left ? parent->subIndicesL : parent->subIndicesR);
00217             bool pure=isNodePure(subIndices,outputs);
00218             add_terminator(vin,outputs,parent,left,pure);
00219             return;
00220         }
00221     }
00222 
00223 
00224     vcl_set<unsigned >& subIndices=(left ? parent->subIndicesL : parent->subIndicesR);
00225     clsfy_binary_tree_op dummyOp;
00226     parent->add_child(dummyOp,left);
00227 
00228 
00229     clsfy_binary_tree_bnode* pChild=dynamic_cast< clsfy_binary_tree_bnode*>(left ? parent->left_child_ : parent->right_child_);
00230     build_a_node(vin,outputs,subIndices,pChild);
00231 
00232     // Check that this actually managed to produce a split (in case we have homogeneous data)
00233     if (pChild->subIndicesL.empty() || pChild->subIndicesR.empty())
00234     {
00235         // We should not have added this child as it's not managed to produce a split
00236         // Backtrack
00237         delete pChild;
00238         if (left)
00239             parent->left_child_=0;
00240         else
00241             parent->right_child_=0;
00242         // Can't go any deeper on this branch
00243         vcl_set<unsigned >& subIndices=(left ? parent->subIndicesL : parent->subIndicesR);
00244         bool pure=isNodePure(subIndices,outputs);
00245         add_terminator(vin,outputs,parent,left,pure);
00246         return;
00247     }
00248 
00249     // May need to check min dataset size in next generation children
00250     if (min_node_size_>0)
00251     {
00252         if (pChild->subIndicesL.size() < static_cast<vcl_size_t>(min_node_size_) ||
00253             pChild->subIndicesR.size() < static_cast<vcl_size_t>(min_node_size_) )
00254         {
00255             // We should not have added this child as it's based on too small a split
00256             // Backtrack
00257             delete pChild;
00258             if (left)
00259                 parent->left_child_=0;
00260             else
00261                 parent->right_child_=0;
00262             // Can't go any deeper on this branch
00263             vcl_set<unsigned >& subIndices=(left ? parent->subIndicesL : parent->subIndicesR);
00264             bool pure=isNodePure(subIndices,outputs);
00265             add_terminator(vin,outputs,parent,left,pure);
00266             return;
00267         }
00268     }
00269 
00270 
00271     clsfy_binary_tree_bnode* pNode=pChild;
00272     // Go on down the left branch of the split we just introduced
00273 
00274     // Check if left node is pure
00275     bool pure=isNodePure(pNode->subIndicesL,outputs);
00276     bool myLeft=true;
00277     bool myRight=false;
00278     if (!pure)
00279     {
00280         build_children(vin,outputs,pNode,myLeft);
00281     }
00282     else // Add a dummy classifier on a pure node so it always returns that class
00283     {
00284         add_terminator(vin,outputs,pNode,myLeft,true);
00285     }
00286 
00287     // Go on down the right branch of the split we just introduced
00288     // Check if right node is pure
00289     pure=isNodePure(pNode->subIndicesR,outputs);
00290     if (!pure)
00291     {
00292         build_children(vin,outputs,pNode,myRight);
00293     }
00294     else
00295     {
00296         add_terminator(vin,outputs,pNode,myRight,true);
00297     }
00298 }
00299 
00300 
00301 void clsfy_binary_tree_builder::copy_children(clsfy_binary_tree_bnode* pBuilderNode,
00302                                               clsfy_binary_tree_node* pNode) const
00303 {
00304     bool left=true;
00305     set_node_prob(pNode,pBuilderNode);
00306 
00307     if (pBuilderNode->left_child_)
00308     {
00309         pNode->add_child(pBuilderNode->left_child_->op_,left);
00310         copy_children(static_cast<clsfy_binary_tree_bnode*>(pBuilderNode->left_child_),
00311                       pNode->left_child_);
00312     }
00313     if (pBuilderNode->right_child_)
00314     {
00315         pNode->add_child(pBuilderNode->right_child_->op_,!left);
00316         copy_children(static_cast<clsfy_binary_tree_bnode*>(pBuilderNode->right_child_),
00317                       pNode->right_child_);
00318     }
00319 }
00320 
00321 void clsfy_binary_tree_builder::build_a_node(
00322     const vcl_vector<vnl_vector<double> >& vin,
00323     const vcl_vector<unsigned>& outputs,
00324     const vcl_set<unsigned >& subIndices,
00325     clsfy_binary_tree_bnode* pNode) const
00326 {
00327     clsfy_binary_threshold_1d_gini_builder tbuilder;
00328     unsigned ndims=vin.front().size();
00329     unsigned ndimsUsed=ndims;
00330     vcl_vector<unsigned > param_indices;
00331     if (nbranch_params_>0) // Random forest style random subset selection
00332     {
00333         ndimsUsed=nbranch_params_;
00334         ndimsUsed=vcl_min(ndimsUsed,ndims);
00335         if (ndimsUsed<ndims)
00336         {
00337             // Note always do a full random permutation as the ones beyond ndimsUsed
00338             // may be needed as a fallback in exceptional cases where the initial random
00339             // subset cannot split the data
00340             randomise_parameters(ndims,param_indices);
00341         }
00342         else
00343         {
00344             ndimsUsed=ndims;
00345             param_indices.resize(ndims);
00346             mbl_stl_increments(param_indices.begin(),param_indices.end(),0);
00347         }
00348     }
00349     else
00350     {
00351         ndimsUsed=ndims;
00352         param_indices.resize(ndims);
00353         mbl_stl_increments(param_indices.begin(),param_indices.end(),0);
00354     }
00355     vcl_vector<clsfy_classifier_1d*> pBranchClassifiers(ndims,0);
00356     vnl_vector<double > wts(subIndices.size());
00357     wts.fill(1.0/double (vin.size())-1.0E-12);
00358     unsigned npoints=subIndices.size();
00359     vcl_vector<unsigned > subOutputs;
00360     subOutputs.reserve(npoints);
00361     vcl_transform(subIndices.begin(),subIndices.end(),
00362                   vcl_back_inserter(subOutputs),
00363                   mbl_stl_index_functor<unsigned >(outputs));
00364 
00365     double minError=1.0E30;
00366     unsigned ibest=0;
00367 
00368     vnl_vector<double > data(npoints);
00369 
00370     // May need a second pass because there may be a subset of homogeneous parameters
00371     // if we are only using a random subset which cannot produce a split
00372     // So try first with the random subset, and then with all if that fails
00373     // Note this assumes param_indices contains the complete set
00374     unsigned npasses=1;
00375     if (ndimsUsed<ndims) npasses=2;
00376     for (unsigned ipass=0;ipass<npasses;++ipass)
00377     {
00378         unsigned istart=0;
00379         unsigned nmax=ndimsUsed;
00380         if (ipass>0)
00381         {
00382             istart=ndimsUsed;
00383             nmax=ndims;
00384         }
00385         for (unsigned idim=istart;idim<nmax;++idim)
00386         {
00387             pBranchClassifiers[idim] = tbuilder.new_classifier();
00388             vcl_set<unsigned >::const_iterator indIter=subIndices.begin();
00389             vcl_set<unsigned >::const_iterator indIterEnd=subIndices.end();
00390             unsigned ipt=0;
00391             while (indIter != indIterEnd)
00392             {
00393                 data[ipt] = vin[*indIter][param_indices[idim]];
00394                 ++ipt;
00395                 ++indIter;
00396             }
00397 
00398             double error=tbuilder.build_gini(*pBranchClassifiers[idim],
00399                                              data,subOutputs);
00400             if (error<minError)
00401             {
00402                 minError=error;
00403                 ibest=idim;
00404             }
00405         }
00406 
00407         pNode->subIndicesL.clear();
00408         pNode->subIndicesR.clear();
00409         clsfy_binary_tree_op op(0,param_indices[ibest]);
00410         op.classifier() = *(static_cast<clsfy_binary_threshold_1d*>(pBranchClassifiers[ibest]));
00411 
00412         pNode->op_=op;
00413 
00414         // Now reapply to all relevant data to construct the subset split
00415         vcl_set<unsigned >::const_iterator indIter=subIndices.begin();
00416         vcl_set<unsigned >::const_iterator indIterEnd=subIndices.end();
00417         vcl_set<unsigned >& subIndicesL=pNode->subIndicesL;
00418         vcl_set<unsigned >& subIndicesR=pNode->subIndicesR;
00419         while (indIter != indIterEnd)
00420         {
00421             double x = vin[*indIter][param_indices[ibest]];
00422             if (pBranchClassifiers[ibest]->classify(x)==0)
00423                 subIndicesL.insert(*indIter);
00424             else
00425                 subIndicesR.insert(*indIter);
00426             ++indIter;
00427         }
00428 
00429         if (!subIndicesL.empty() && !subIndicesR.empty()) // Success - it really has split
00430             break; // no second pass needed
00431     }
00432     mbl_stl_clean(pBranchClassifiers.begin(),pBranchClassifiers.end());
00433 }
00434 
00435 bool clsfy_binary_tree_builder::isNodePure(const vcl_set<unsigned >& subIndices,
00436                                            const vcl_vector<unsigned>& outputs) const
00437 {
00438     if (subIndices.empty()) return true;
00439     vcl_set<unsigned >::const_iterator indIter=subIndices.begin();
00440     vcl_set<unsigned >::const_iterator indIterEnd=subIndices.end();
00441 
00442     unsigned class0=outputs[*indIter];
00443     while (indIter != indIterEnd)
00444     {
00445         if (outputs[*indIter] != class0)
00446             return false;
00447         ++indIter;
00448     }
00449     return true;
00450 }
00451 
00452 //: Add dummy node to represent a pure node
00453 // The threshold is set either very low or very high
00454 void clsfy_binary_tree_builder::add_terminator(
00455     const vcl_vector<vnl_vector<double> >& vin,
00456     const vcl_vector<unsigned>& outputs,
00457     clsfy_binary_tree_bnode* parent,
00458     bool left, bool pure) const
00459 {
00460     double thresholdBig=1.0E30;
00461 
00462     int dummyIndex=0;
00463     clsfy_binary_tree_op dummyOp(0,dummyIndex);
00464 
00465 
00466     unsigned classification=0;
00467     double prob=0.5;
00468     if (pure)
00469     {
00470         if (left)
00471         {
00472             if (!(parent->subIndicesL.empty()))
00473                 classification=outputs[*(parent->subIndicesL.begin())];
00474         }
00475         else
00476         {
00477             classification=1;
00478             if (!(parent->subIndicesR.empty()))
00479                 classification=outputs[*(parent->subIndicesR.begin())];
00480         }
00481         prob = (classification==1 ? 1.0 : 0.0);
00482     }
00483     else // Mixed node - assess ratio of classes
00484     {
00485         vcl_set<unsigned >& indices=(left ? parent->subIndicesL : parent->subIndicesR);
00486         vcl_set<unsigned >::iterator indexIter=indices.begin();
00487         vcl_set<unsigned >::iterator indexIterEnd=indices.end();
00488         unsigned n1=0;
00489         while (indexIter != indexIterEnd)
00490         {
00491             if (outputs[*indexIter]==1)
00492                 ++n1;
00493             ++indexIter;
00494         }
00495         prob=double (n1)/double (indices.size());
00496         classification = (prob>0.5 ? 1 : 0);
00497     }
00498     double parity=(classification==0 ? 1.0 : -1.0);
00499     dummyOp.classifier().set(1.0,parity*thresholdBig);
00500 
00501     parent->add_child(dummyOp,left);
00502 
00503     if (left)
00504         parent->left_child_->prob_=prob;
00505     else
00506         parent->right_child_->prob_=prob;
00507 }
00508 
00509 //=======================================================================
00510 //: Create empty classifier
00511 // Caller is responsible for deletion
00512 clsfy_classifier_base* clsfy_binary_tree_builder::new_classifier() const
00513 {
00514     return new clsfy_binary_tree();
00515 }
00516 
00517 void  clsfy_binary_tree_builder::randomise_parameters(unsigned ndimsUsed,
00518                                                       vcl_vector<unsigned  >& param_indices) const
00519 {
00520     // In fact it shuffles all indices (in case the random subset does not produce a split)
00521     param_indices.resize(base_indices_.size());
00522 
00523     vcl_random_shuffle(base_indices_.begin(),base_indices_.end(),random_sampler_);
00524     vcl_copy(base_indices_.begin(),base_indices_.end(),
00525              param_indices.begin());
00526 }
00527 
00528 
00529 void clsfy_binary_tree_builder::seed_sampler(unsigned long seed)
00530 {
00531     random_sampler_.reseed(seed);
00532 }
00533 
00534 void clsfy_binary_tree_builder::set_node_prob(clsfy_binary_tree_node* pNode,
00535                                               clsfy_binary_tree_bnode* pBuilderNode) const
00536 {
00537     pNode->prob_ = pBuilderNode->prob_;
00538 }
00539 
00540 //=========================== Helper Classes =============================
00541 clsfy_binary_tree_node* clsfy_binary_tree_bnode::create_child(const clsfy_binary_tree_op& op)
00542 {
00543     return new clsfy_binary_tree_bnode(this,op);
00544 }
00545 
00546 clsfy_binary_tree_bnode::~clsfy_binary_tree_bnode()
00547 {
00548 }