contrib/mul/clsfy/clsfy_binary_tree.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_tree.cxx
00002 #include "clsfy_binary_tree.h"
00003 //:
00004 // \file
00005 // \brief Binary tree classifier
00006 // \author Martin Roberts
00007 
00008 #include <vcl_string.h>
00009 #include <vcl_deque.h>
00010 #include <vcl_algorithm.h>
00011 #include <vcl_iterator.h>
00012 #include <vcl_cmath.h>
00013 #include <vcl_cassert.h>
00014 #include <vsl/vsl_binary_io.h>
00015 #include <vsl/vsl_vector_io.h>
00016 #include <vnl/io/vnl_io_vector.h>
00017 #include <mbl/mbl_stl.h>
00018 
00019 
00020 clsfy_binary_tree::clsfy_binary_tree(const clsfy_binary_tree& srcTree)
00021 {
00022     root_=cache_node_=0;
00023     copy(srcTree);
00024 }
00025 
00026 clsfy_binary_tree& clsfy_binary_tree::operator=(const clsfy_binary_tree& srcTree)
00027 {
00028     if (&srcTree != this)
00029     {
00030         copy(srcTree);
00031     }
00032     return *this;
00033 }
00034 
00035 void clsfy_binary_tree::copy(const clsfy_binary_tree& srcTree)
00036 {
00037     remove_tree(root_);
00038     //Then copy into the classifier
00039     if (srcTree.root_)
00040     {
00041         root_ = new clsfy_binary_tree_node(0,srcTree.root_->op_);
00042         root_->prob_ = srcTree.root_->prob_;
00043         copy_children(srcTree.root_,root_);
00044     }
00045     else
00046         root_=0;
00047     cache_node_ = root_;
00048 }
00049 
00050 void clsfy_binary_tree::copy_children(clsfy_binary_tree_node* pSrcNode,clsfy_binary_tree_node* pNode)
00051 {
00052     bool left=true;
00053     pNode->prob_ = pSrcNode->prob_;
00054     if (pSrcNode->left_child_)
00055     {
00056         pNode->add_child(pSrcNode->left_child_->op_,left);
00057         copy_children(pSrcNode->left_child_,
00058                       pNode->left_child_);
00059     }
00060     if (pSrcNode->right_child_)
00061     {
00062         pNode->add_child(pSrcNode->right_child_->op_,!left);
00063         copy_children(pSrcNode->right_child_,
00064                       pNode->right_child_);
00065     }
00066 }
00067 
00068 //=======================================================================
00069 //: Return the classification of the given probe vector.
00070 unsigned clsfy_binary_tree::classify(const vnl_vector<double> &input) const
00071 {
00072     unsigned outClass=0;
00073     //Traverse the tree
00074     clsfy_binary_tree_node* pNode=root_;
00075     if (!pNode)
00076     {
00077         vcl_cerr<<"WARNING - empty tree in clsfy_binary_tree::classify\n"
00078                 <<"Return default classification zero\n";
00079         return 0;
00080     }
00081     clsfy_binary_tree_node* pChild=0;
00082     do //Keep dropping down the tree till reach base level
00083     {
00084         pNode->op_.set_data(input);
00085         unsigned indicator=pNode->op_.classify();
00086         if (indicator==0)
00087         {
00088             pChild=pNode->left_child_;
00089         }
00090         else
00091         {
00092             pChild=pNode->right_child_;
00093         }
00094         if (pChild)
00095             pNode=pChild;
00096         else
00097         {
00098             cache_node_ = pNode; //Store final node (in case probability accessed)
00099             outClass=(pNode->prob_>0.5 ? 1 : 0);
00100         }
00101     }while (pChild);
00102 
00103     return outClass;
00104 }
00105 
00106 //=======================================================================
00107 //: Return a probability like value that the input being in each class.
00108 // output(i) i<<nClasses, contains the probability that the input is in class i
00109 void clsfy_binary_tree::class_probabilities(vcl_vector<double>& outputs,
00110                                             vnl_vector<double>const& input) const
00111 {
00112     outputs.resize(1);
00113     classify(input);
00114     outputs[0] = cache_node_->prob_;
00115 }
00116 
00117 
00118 //=======================================================================
00119 //: The dimensionality of input vectors.
00120 unsigned clsfy_binary_tree::n_dims() const
00121 {
00122     clsfy_binary_tree_node* pNode=root_;
00123     if (pNode)
00124         return pNode->op_.ndims();
00125     else
00126         return 0;
00127 }
00128 
00129 //=======================================================================
00130 //: This value has properties of a Log likelihood of being in class (binary classifiers only)
00131 // class probability = exp(logL) / (1+exp(logL))
00132 double clsfy_binary_tree::log_l(const vnl_vector<double> &input) const
00133 {
00134     vcl_vector<double > probs;
00135     class_probabilities(probs,input);
00136     double p1=probs[0];
00137     double p0=1-p1;
00138     const double epsilon=1.0E-8;
00139     if (p0<epsilon) p0=epsilon;
00140     double L=vcl_log(p1/p0);
00141 
00142     return L;
00143 }
00144 
00145 
00146 //=======================================================================
00147 
00148 vcl_string clsfy_binary_tree::is_a() const
00149 {
00150     return vcl_string("clsfy_binary_tree");
00151 }
00152 
00153 //=======================================================================
00154 
00155 bool clsfy_binary_tree::is_class(vcl_string const& s) const
00156 {
00157     return s == clsfy_binary_tree::is_a() || clsfy_classifier_base::is_class(s);
00158 }
00159 
00160 //=======================================================================
00161 
00162 short clsfy_binary_tree::version_no() const
00163 {
00164     return 1;
00165 }
00166 
00167 //=======================================================================
00168 
00169 clsfy_classifier_base* clsfy_binary_tree::clone() const
00170 {
00171     return new clsfy_binary_tree(*this);
00172 }
00173 
00174 //=======================================================================
00175 
00176 void clsfy_binary_tree::print_summary(vcl_ostream& /*os*/) const
00177 {
00178 }
00179 
00180 //=======================================================================
00181 
00182 void clsfy_binary_tree::b_write(vsl_b_ostream& bfs) const
00183 {
00184     vsl_b_write(bfs,version_no());
00185     int nodeId=0; //used numeric ids for parent child relations
00186     // -1 means none
00187     vcl_deque<clsfy_binary_tree_node*> stack;
00188     vcl_deque<clsfy_binary_tree_node*> outlist;
00189     vcl_vector<graph_rep> arcs;
00190     clsfy_binary_tree_node* pNode=root_;
00191 
00192     stack.push_back(pNode);
00193     pNode->nodeId_=0;
00194     while (!stack.empty())
00195     {
00196         pNode=stack.front();
00197         stack.pop_front();
00198         outlist.push_back(pNode);
00199         graph_rep link;
00200         link.me=pNode->nodeId_;
00201         link.left_child = link.right_child = -1;
00202 
00203         if (pNode)
00204         {
00205             if (pNode->left_child_)
00206             {
00207                 stack.push_back(pNode->left_child_);
00208                 pNode->left_child_->nodeId_= ++nodeId;
00209                 link.left_child=nodeId;
00210             }
00211             if (pNode->right_child_)
00212             {
00213                 stack.push_back(pNode->right_child_);
00214                 pNode->right_child_->nodeId_= ++nodeId;
00215                 link.right_child=nodeId;
00216             }
00217 
00218             arcs.push_back(link);
00219         }
00220     }
00221 
00222     unsigned N=outlist.size();
00223     vsl_b_write(bfs,N);
00224 
00225     vcl_deque<clsfy_binary_tree_node*>::iterator outIter=outlist.begin();
00226     vcl_deque<clsfy_binary_tree_node*>::iterator outIterEnd=outlist.end();
00227     while (outIter != outIterEnd)
00228     {
00229         clsfy_binary_tree_node* pNode=*outIter;
00230         vsl_b_write(bfs,pNode->nodeId_);
00231         pNode->op_.b_write(bfs);
00232         vsl_b_write(bfs,pNode->prob_);
00233         ++outIter;
00234     }
00235 
00236     //Now write out the links graph
00237     N=arcs.size();
00238     vsl_b_write(bfs,N);
00239 
00240     vcl_vector<graph_rep>::iterator arcIter=arcs.begin();
00241     vcl_vector<graph_rep>::iterator arcIterEnd=arcs.end();
00242 
00243     while (arcIter != arcIterEnd)
00244     {
00245         vsl_b_write(bfs,arcIter->me);
00246         vsl_b_write(bfs,arcIter->left_child);
00247         vsl_b_write(bfs,arcIter->right_child);
00248         ++arcIter;
00249     }
00250 }
00251 
00252 //=======================================================================
00253 
00254 void clsfy_binary_tree::b_read(vsl_b_istream& bfs)
00255 {
00256     if (!bfs) return;
00257 
00258     remove_tree(root_);
00259     root_=0;
00260 
00261     short version;
00262     vsl_b_read(bfs,version);
00263     switch (version)
00264     {
00265         case (1):
00266         {
00267             vcl_map<int,clsfy_binary_tree_node*> workmap;
00268             vcl_vector<graph_rep> arcs;
00269 
00270             clsfy_binary_tree_node* pNull=0;
00271             unsigned N;
00272             vsl_b_read(bfs,N);
00273             for (unsigned i=0;i<N;++i)
00274             {
00275                 int nodeId=-1;
00276                 vsl_b_read(bfs,nodeId);
00277                 clsfy_binary_tree_op op;
00278                 op.b_read(bfs);
00279                 clsfy_binary_tree_node* pNode=new clsfy_binary_tree_node(pNull,op);
00280                 pNode->nodeId_=nodeId;
00281                 vsl_b_read(bfs,pNode->prob_);
00282                 workmap[nodeId]=pNode;
00283             }
00284             vsl_b_read(bfs,N);
00285             arcs.reserve(N);
00286             for (unsigned i=0;i<N;++i)
00287             {
00288                 graph_rep link;
00289                 vsl_b_read(bfs,link.me);
00290                 vsl_b_read(bfs,link.left_child);
00291                 vsl_b_read(bfs,link.right_child);
00292                 arcs.push_back(link);
00293             }
00294             root_=workmap[0];
00295             for (unsigned i=0;i<N;++i)
00296             {
00297                 graph_rep link=arcs[i];
00298                 if (link.me!= -1)
00299                 {
00300                     clsfy_binary_tree_node* parent=workmap[link.me];
00301                     clsfy_binary_tree_node* left_child=0;
00302                     clsfy_binary_tree_node* right_child=0;
00303                     if (link.left_child != -1)
00304                         left_child=workmap[link.left_child];
00305                     if (link.right_child != -1)
00306                         right_child=workmap[link.right_child];
00307 
00308                     if (!parent || parent->nodeId_ != link.me)
00309                     {
00310                         vcl_cerr<<"ERROR - Inconsistent parent in tree set up in clsfy_binary_tree::b_read\n";
00311                         assert(0);
00312                     }
00313                     if ((link.left_child != -1) &&
00314                         (!left_child || left_child->nodeId_ != link.left_child))
00315                                         {
00316                         vcl_cerr<<"ERROR - Inconsistent left child in tree set up in clsfy_binary_tree::b_read\n";
00317                         assert(0);
00318                     }
00319                     if ((link.right_child != -1) &&
00320                         (!right_child || right_child->nodeId_ != link.right_child))
00321                                         {
00322                         vcl_cerr<<"ERROR - Inconsistent right child in tree set up in clsfy_binary_tree::b_read\n";
00323                         assert(0);
00324                     }
00325 
00326                     //And link these into the tree
00327                     parent->left_child_=left_child;
00328                     if (left_child)
00329                         left_child->parent_=parent;
00330 
00331                     parent->right_child_=right_child;
00332                     if (right_child)
00333                         right_child->parent_=parent;
00334                 }
00335             }
00336 
00337             //Validate the tree
00338             assert(root_);
00339             vcl_map<int,clsfy_binary_tree_node*>::iterator nodeIter =workmap.begin();
00340             vcl_map<int,clsfy_binary_tree_node*>::iterator nodeIterEnd =workmap.end();
00341             while (nodeIter != nodeIterEnd)
00342             {
00343                 clsfy_binary_tree_node* pNode=nodeIter->second;
00344                 assert(pNode->nodeId_==nodeIter->first);
00345                 if (pNode != root_)
00346                 {
00347                     assert(pNode->parent_);
00348                     assert(pNode->parent_->left_child_==pNode ||
00349                            pNode->parent_->right_child_ == pNode);
00350                 }
00351                 if (pNode->left_child_)
00352                     assert(pNode->left_child_->parent_==pNode);
00353                 if (pNode->right_child_)
00354                     assert(pNode->right_child_->parent_==pNode);
00355 
00356                 //Check all nodes connect back up to root
00357                 while (pNode->parent_)
00358                 {
00359                     pNode=pNode->parent_;
00360                 }
00361                 assert(pNode==root_);
00362 
00363                 ++nodeIter;
00364             }
00365         }
00366         break;
00367 
00368         default:
00369             vcl_cerr << "I/O ERROR: clsfy_binary_tree::b_read(vsl_b_istream&)\n"
00370                      << "           Unknown version number "<< version << '\n';
00371             bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00372     }
00373 }
00374 
00375 clsfy_binary_tree::~clsfy_binary_tree()
00376 {
00377     remove_tree(root_);
00378     root_=0;
00379 }
00380 
00381 void  clsfy_binary_tree::remove_tree(clsfy_binary_tree_node* root)
00382 {
00383     vcl_deque<clsfy_binary_tree_node*> stack;
00384     vcl_deque<clsfy_binary_tree_node*> killset;
00385     stack.push_back(root);
00386     while (!stack.empty())
00387     {
00388         clsfy_binary_tree_node* pNode=stack.front();
00389         stack.pop_front();
00390 
00391         if (pNode)
00392         {
00393             killset.push_back(pNode);
00394             if (pNode->left_child_)
00395             {
00396                 stack.push_back(pNode->left_child_);
00397             }
00398             if (pNode->right_child_)
00399             {
00400                 stack.push_back(pNode->right_child_);
00401             }
00402         }
00403     }
00404 
00405     mbl_stl_clean(killset.begin(),killset.end());
00406 }
00407 
00408 void clsfy_binary_tree::set_root(  clsfy_binary_tree_node* root)
00409 {
00410     if ((root != root_) && root_)
00411         remove_tree(root_);
00412     root_=root;
00413 }
00414 
00415 
00416 //--------------- HELPER CLASSES---------------------------------------------------------
00417 
00418 clsfy_binary_tree_node* clsfy_binary_tree_node::create_child(const clsfy_binary_tree_op& op)
00419 {
00420     return new clsfy_binary_tree_node(this,op);
00421 }
00422 
00423 void clsfy_binary_tree_op::b_write(vsl_b_ostream& bfs) const
00424 {
00425     vsl_b_write(bfs,version_no());
00426     vsl_b_write(bfs,data_index_);
00427     vsl_b_write(bfs,classifier_);
00428 }
00429 
00430 //: Load the class from a Binary File Stream
00431 void clsfy_binary_tree_op::b_read(vsl_b_istream& bfs)
00432 {
00433     short version;
00434     vsl_b_read(bfs,version);
00435     if (version != 1)
00436     {
00437         vcl_cerr << "I/O ERROR: clsfy_binary_tree::b_read(vsl_b_istream&)\n"
00438                  << "           Unknown version number "<< version << '\n';
00439         bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00440     }
00441     else
00442     {
00443         vsl_b_read(bfs,data_index_);
00444         vsl_b_read(bfs,classifier_);
00445     }
00446 }
00447