Go to the documentation of this file.00001
00002 #include "clsfy_binary_tree.h"
00003
00004
00005
00006
00007
00008 #include <vcl_string.h>
00009 #include <vcl_deque.h>
00010 #include <vcl_algorithm.h>
00011 #include <vcl_iterator.h>
00012 #include <vcl_cmath.h>
00013 #include <vcl_cassert.h>
00014 #include <vsl/vsl_binary_io.h>
00015 #include <vsl/vsl_vector_io.h>
00016 #include <vnl/io/vnl_io_vector.h>
00017 #include <mbl/mbl_stl.h>
00018
00019
00020 clsfy_binary_tree::clsfy_binary_tree(const clsfy_binary_tree& srcTree)
00021 {
00022 root_=cache_node_=0;
00023 copy(srcTree);
00024 }
00025
00026 clsfy_binary_tree& clsfy_binary_tree::operator=(const clsfy_binary_tree& srcTree)
00027 {
00028 if (&srcTree != this)
00029 {
00030 copy(srcTree);
00031 }
00032 return *this;
00033 }
00034
00035 void clsfy_binary_tree::copy(const clsfy_binary_tree& srcTree)
00036 {
00037 remove_tree(root_);
00038
00039 if (srcTree.root_)
00040 {
00041 root_ = new clsfy_binary_tree_node(0,srcTree.root_->op_);
00042 root_->prob_ = srcTree.root_->prob_;
00043 copy_children(srcTree.root_,root_);
00044 }
00045 else
00046 root_=0;
00047 cache_node_ = root_;
00048 }
00049
00050 void clsfy_binary_tree::copy_children(clsfy_binary_tree_node* pSrcNode,clsfy_binary_tree_node* pNode)
00051 {
00052 bool left=true;
00053 pNode->prob_ = pSrcNode->prob_;
00054 if (pSrcNode->left_child_)
00055 {
00056 pNode->add_child(pSrcNode->left_child_->op_,left);
00057 copy_children(pSrcNode->left_child_,
00058 pNode->left_child_);
00059 }
00060 if (pSrcNode->right_child_)
00061 {
00062 pNode->add_child(pSrcNode->right_child_->op_,!left);
00063 copy_children(pSrcNode->right_child_,
00064 pNode->right_child_);
00065 }
00066 }
00067
00068
00069
00070 unsigned clsfy_binary_tree::classify(const vnl_vector<double> &input) const
00071 {
00072 unsigned outClass=0;
00073
00074 clsfy_binary_tree_node* pNode=root_;
00075 if (!pNode)
00076 {
00077 vcl_cerr<<"WARNING - empty tree in clsfy_binary_tree::classify\n"
00078 <<"Return default classification zero\n";
00079 return 0;
00080 }
00081 clsfy_binary_tree_node* pChild=0;
00082 do
00083 {
00084 pNode->op_.set_data(input);
00085 unsigned indicator=pNode->op_.classify();
00086 if (indicator==0)
00087 {
00088 pChild=pNode->left_child_;
00089 }
00090 else
00091 {
00092 pChild=pNode->right_child_;
00093 }
00094 if (pChild)
00095 pNode=pChild;
00096 else
00097 {
00098 cache_node_ = pNode;
00099 outClass=(pNode->prob_>0.5 ? 1 : 0);
00100 }
00101 }while (pChild);
00102
00103 return outClass;
00104 }
00105
00106
00107
00108
00109 void clsfy_binary_tree::class_probabilities(vcl_vector<double>& outputs,
00110 vnl_vector<double>const& input) const
00111 {
00112 outputs.resize(1);
00113 classify(input);
00114 outputs[0] = cache_node_->prob_;
00115 }
00116
00117
00118
00119
00120 unsigned clsfy_binary_tree::n_dims() const
00121 {
00122 clsfy_binary_tree_node* pNode=root_;
00123 if (pNode)
00124 return pNode->op_.ndims();
00125 else
00126 return 0;
00127 }
00128
00129
00130
00131
00132 double clsfy_binary_tree::log_l(const vnl_vector<double> &input) const
00133 {
00134 vcl_vector<double > probs;
00135 class_probabilities(probs,input);
00136 double p1=probs[0];
00137 double p0=1-p1;
00138 const double epsilon=1.0E-8;
00139 if (p0<epsilon) p0=epsilon;
00140 double L=vcl_log(p1/p0);
00141
00142 return L;
00143 }
00144
00145
00146
00147
00148 vcl_string clsfy_binary_tree::is_a() const
00149 {
00150 return vcl_string("clsfy_binary_tree");
00151 }
00152
00153
00154
00155 bool clsfy_binary_tree::is_class(vcl_string const& s) const
00156 {
00157 return s == clsfy_binary_tree::is_a() || clsfy_classifier_base::is_class(s);
00158 }
00159
00160
00161
00162 short clsfy_binary_tree::version_no() const
00163 {
00164 return 1;
00165 }
00166
00167
00168
00169 clsfy_classifier_base* clsfy_binary_tree::clone() const
00170 {
00171 return new clsfy_binary_tree(*this);
00172 }
00173
00174
00175
00176 void clsfy_binary_tree::print_summary(vcl_ostream& ) const
00177 {
00178 }
00179
00180
00181
00182 void clsfy_binary_tree::b_write(vsl_b_ostream& bfs) const
00183 {
00184 vsl_b_write(bfs,version_no());
00185 int nodeId=0;
00186
00187 vcl_deque<clsfy_binary_tree_node*> stack;
00188 vcl_deque<clsfy_binary_tree_node*> outlist;
00189 vcl_vector<graph_rep> arcs;
00190 clsfy_binary_tree_node* pNode=root_;
00191
00192 stack.push_back(pNode);
00193 pNode->nodeId_=0;
00194 while (!stack.empty())
00195 {
00196 pNode=stack.front();
00197 stack.pop_front();
00198 outlist.push_back(pNode);
00199 graph_rep link;
00200 link.me=pNode->nodeId_;
00201 link.left_child = link.right_child = -1;
00202
00203 if (pNode)
00204 {
00205 if (pNode->left_child_)
00206 {
00207 stack.push_back(pNode->left_child_);
00208 pNode->left_child_->nodeId_= ++nodeId;
00209 link.left_child=nodeId;
00210 }
00211 if (pNode->right_child_)
00212 {
00213 stack.push_back(pNode->right_child_);
00214 pNode->right_child_->nodeId_= ++nodeId;
00215 link.right_child=nodeId;
00216 }
00217
00218 arcs.push_back(link);
00219 }
00220 }
00221
00222 unsigned N=outlist.size();
00223 vsl_b_write(bfs,N);
00224
00225 vcl_deque<clsfy_binary_tree_node*>::iterator outIter=outlist.begin();
00226 vcl_deque<clsfy_binary_tree_node*>::iterator outIterEnd=outlist.end();
00227 while (outIter != outIterEnd)
00228 {
00229 clsfy_binary_tree_node* pNode=*outIter;
00230 vsl_b_write(bfs,pNode->nodeId_);
00231 pNode->op_.b_write(bfs);
00232 vsl_b_write(bfs,pNode->prob_);
00233 ++outIter;
00234 }
00235
00236
00237 N=arcs.size();
00238 vsl_b_write(bfs,N);
00239
00240 vcl_vector<graph_rep>::iterator arcIter=arcs.begin();
00241 vcl_vector<graph_rep>::iterator arcIterEnd=arcs.end();
00242
00243 while (arcIter != arcIterEnd)
00244 {
00245 vsl_b_write(bfs,arcIter->me);
00246 vsl_b_write(bfs,arcIter->left_child);
00247 vsl_b_write(bfs,arcIter->right_child);
00248 ++arcIter;
00249 }
00250 }
00251
00252
00253
00254 void clsfy_binary_tree::b_read(vsl_b_istream& bfs)
00255 {
00256 if (!bfs) return;
00257
00258 remove_tree(root_);
00259 root_=0;
00260
00261 short version;
00262 vsl_b_read(bfs,version);
00263 switch (version)
00264 {
00265 case (1):
00266 {
00267 vcl_map<int,clsfy_binary_tree_node*> workmap;
00268 vcl_vector<graph_rep> arcs;
00269
00270 clsfy_binary_tree_node* pNull=0;
00271 unsigned N;
00272 vsl_b_read(bfs,N);
00273 for (unsigned i=0;i<N;++i)
00274 {
00275 int nodeId=-1;
00276 vsl_b_read(bfs,nodeId);
00277 clsfy_binary_tree_op op;
00278 op.b_read(bfs);
00279 clsfy_binary_tree_node* pNode=new clsfy_binary_tree_node(pNull,op);
00280 pNode->nodeId_=nodeId;
00281 vsl_b_read(bfs,pNode->prob_);
00282 workmap[nodeId]=pNode;
00283 }
00284 vsl_b_read(bfs,N);
00285 arcs.reserve(N);
00286 for (unsigned i=0;i<N;++i)
00287 {
00288 graph_rep link;
00289 vsl_b_read(bfs,link.me);
00290 vsl_b_read(bfs,link.left_child);
00291 vsl_b_read(bfs,link.right_child);
00292 arcs.push_back(link);
00293 }
00294 root_=workmap[0];
00295 for (unsigned i=0;i<N;++i)
00296 {
00297 graph_rep link=arcs[i];
00298 if (link.me!= -1)
00299 {
00300 clsfy_binary_tree_node* parent=workmap[link.me];
00301 clsfy_binary_tree_node* left_child=0;
00302 clsfy_binary_tree_node* right_child=0;
00303 if (link.left_child != -1)
00304 left_child=workmap[link.left_child];
00305 if (link.right_child != -1)
00306 right_child=workmap[link.right_child];
00307
00308 if (!parent || parent->nodeId_ != link.me)
00309 {
00310 vcl_cerr<<"ERROR - Inconsistent parent in tree set up in clsfy_binary_tree::b_read\n";
00311 assert(0);
00312 }
00313 if ((link.left_child != -1) &&
00314 (!left_child || left_child->nodeId_ != link.left_child))
00315 {
00316 vcl_cerr<<"ERROR - Inconsistent left child in tree set up in clsfy_binary_tree::b_read\n";
00317 assert(0);
00318 }
00319 if ((link.right_child != -1) &&
00320 (!right_child || right_child->nodeId_ != link.right_child))
00321 {
00322 vcl_cerr<<"ERROR - Inconsistent right child in tree set up in clsfy_binary_tree::b_read\n";
00323 assert(0);
00324 }
00325
00326
00327 parent->left_child_=left_child;
00328 if (left_child)
00329 left_child->parent_=parent;
00330
00331 parent->right_child_=right_child;
00332 if (right_child)
00333 right_child->parent_=parent;
00334 }
00335 }
00336
00337
00338 assert(root_);
00339 vcl_map<int,clsfy_binary_tree_node*>::iterator nodeIter =workmap.begin();
00340 vcl_map<int,clsfy_binary_tree_node*>::iterator nodeIterEnd =workmap.end();
00341 while (nodeIter != nodeIterEnd)
00342 {
00343 clsfy_binary_tree_node* pNode=nodeIter->second;
00344 assert(pNode->nodeId_==nodeIter->first);
00345 if (pNode != root_)
00346 {
00347 assert(pNode->parent_);
00348 assert(pNode->parent_->left_child_==pNode ||
00349 pNode->parent_->right_child_ == pNode);
00350 }
00351 if (pNode->left_child_)
00352 assert(pNode->left_child_->parent_==pNode);
00353 if (pNode->right_child_)
00354 assert(pNode->right_child_->parent_==pNode);
00355
00356
00357 while (pNode->parent_)
00358 {
00359 pNode=pNode->parent_;
00360 }
00361 assert(pNode==root_);
00362
00363 ++nodeIter;
00364 }
00365 }
00366 break;
00367
00368 default:
00369 vcl_cerr << "I/O ERROR: clsfy_binary_tree::b_read(vsl_b_istream&)\n"
00370 << " Unknown version number "<< version << '\n';
00371 bfs.is().clear(vcl_ios::badbit);
00372 }
00373 }
00374
00375 clsfy_binary_tree::~clsfy_binary_tree()
00376 {
00377 remove_tree(root_);
00378 root_=0;
00379 }
00380
00381 void clsfy_binary_tree::remove_tree(clsfy_binary_tree_node* root)
00382 {
00383 vcl_deque<clsfy_binary_tree_node*> stack;
00384 vcl_deque<clsfy_binary_tree_node*> killset;
00385 stack.push_back(root);
00386 while (!stack.empty())
00387 {
00388 clsfy_binary_tree_node* pNode=stack.front();
00389 stack.pop_front();
00390
00391 if (pNode)
00392 {
00393 killset.push_back(pNode);
00394 if (pNode->left_child_)
00395 {
00396 stack.push_back(pNode->left_child_);
00397 }
00398 if (pNode->right_child_)
00399 {
00400 stack.push_back(pNode->right_child_);
00401 }
00402 }
00403 }
00404
00405 mbl_stl_clean(killset.begin(),killset.end());
00406 }
00407
00408 void clsfy_binary_tree::set_root( clsfy_binary_tree_node* root)
00409 {
00410 if ((root != root_) && root_)
00411 remove_tree(root_);
00412 root_=root;
00413 }
00414
00415
00416
00417
00418 clsfy_binary_tree_node* clsfy_binary_tree_node::create_child(const clsfy_binary_tree_op& op)
00419 {
00420 return new clsfy_binary_tree_node(this,op);
00421 }
00422
00423 void clsfy_binary_tree_op::b_write(vsl_b_ostream& bfs) const
00424 {
00425 vsl_b_write(bfs,version_no());
00426 vsl_b_write(bfs,data_index_);
00427 vsl_b_write(bfs,classifier_);
00428 }
00429
00430
00431 void clsfy_binary_tree_op::b_read(vsl_b_istream& bfs)
00432 {
00433 short version;
00434 vsl_b_read(bfs,version);
00435 if (version != 1)
00436 {
00437 vcl_cerr << "I/O ERROR: clsfy_binary_tree::b_read(vsl_b_istream&)\n"
00438 << " Unknown version number "<< version << '\n';
00439 bfs.is().clear(vcl_ios::badbit);
00440 }
00441 else
00442 {
00443 vsl_b_read(bfs,data_index_);
00444 vsl_b_read(bfs,classifier_);
00445 }
00446 }
00447