contrib/mul/clsfy/clsfy_binary_tree.h
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_tree.h
00002 #ifndef clsfy_binary_tree_h_
00003 #define clsfy_binary_tree_h_
00004 //:
00005 // \file
00006 // \brief Binary tree classifier
00007 // \author Martin Roberts
00008 #include <clsfy/clsfy_classifier_base.h>
00009 #include <clsfy/clsfy_binary_threshold_1d.h>
00010 #include <vcl_iosfwd.h>
00011 
00012 
00013 //: One node of a binary tree classifier - wrapper round clsfy_binary_threshold_1d
00014 //  Needs also to store the data feature index associated with the node
00015 //  Then it calls its binary classifier for that node
00016 //  Returns class zero if s_*x[i]<threshold_
00017 
00018 class clsfy_binary_tree_op
00019 {
00020  protected:
00021     //Index within data of variable used at this node (set to -1 if none assigned)
00022     int data_index_;
00023     const vnl_vector<double>* data_ptr_;
00024     clsfy_binary_threshold_1d classifier_;
00025 
00026  public:
00027 
00028   clsfy_binary_tree_op() : data_index_(-1), data_ptr_(0) {}
00029   clsfy_binary_tree_op(const vnl_vector<double>* data_ptr,
00030                        int data_index=-1)
00031     : data_index_(data_index), data_ptr_(data_ptr) {}
00032 
00033   clsfy_binary_threshold_1d& classifier() {return classifier_;}
00034   unsigned data_index() const {return data_index_;}
00035   void set_data_index(unsigned index) {data_index_=index;}
00036   void set_data_ptr(const vnl_vector<double>* data_ptr) {data_ptr_= data_ptr;}
00037 
00038   //: Return reference to data - NB throws std::bad_cast if null
00039   const vnl_vector<double >& data() const {return *data_ptr_;}
00040 
00041   void set_data(const vnl_vector<double >& inputs) {data_ptr_=&inputs;}
00042   //: Return value
00043   double val() const {return (*data_ptr_)[data_index_];}
00044 
00045   //: Classify
00046   unsigned classify() {return classifier_.classify(val());}
00047 
00048   unsigned ndims() {return data_ptr_ ? data_ptr_->size() : 0;}
00049 
00050   //: Save class to a binary File Stream
00051   void b_write(vsl_b_ostream& bfs) const;
00052 
00053   //: Load the class from a Binary File Stream
00054   void b_read(vsl_b_istream& bfs);
00055 
00056   short version_no() const {return 1;}
00057 };
00058 
00059 
00060 class clsfy_binary_tree_node
00061 {
00062   int nodeId_;
00063   clsfy_binary_tree_node* parent_;
00064   clsfy_binary_tree_node* left_child_;
00065   clsfy_binary_tree_node* right_child_;
00066   clsfy_binary_tree_op op_;
00067   double prob_; //Only used on terminal nodes
00068  public:
00069 
00070   clsfy_binary_tree_node(clsfy_binary_tree_node* parent,
00071                          const clsfy_binary_tree_op& op)
00072   : nodeId_(-1),parent_(parent),left_child_(0),right_child_(0),op_(op),prob_(0.5) {}
00073 
00074   virtual clsfy_binary_tree_node* create_child(const clsfy_binary_tree_op& op);
00075   void add_child(const clsfy_binary_tree_op& op,bool bLeft)
00076   {
00077     clsfy_binary_tree_node* child=create_child(op);
00078     if (bLeft)
00079       left_child_=child;
00080     else
00081       right_child_=child;
00082   }
00083 
00084   //Note the owning classifier removes the tree - beware as once deleted its children
00085   //may be inaccessible for deletion
00086   virtual ~clsfy_binary_tree_node() {}
00087 
00088   friend class clsfy_binary_tree;
00089   friend class clsfy_binary_tree_builder;
00090 };
00091 
00092 
00093 //: A binary tree classifier
00094 // Drop down the tree using a binary threshold on a specific variable from the set at each node.
00095 // Branch left for one classification, right for the other
00096 // Eventually a node is reached with no children and that node's
00097 // binary threshold classification is returned
00098 
00099 class clsfy_binary_tree : public clsfy_classifier_base
00100 {
00101  public:
00102 
00103   struct graph_rep
00104   {
00105       int me;
00106       int left_child;
00107       int right_child;
00108   };
00109 
00110   //: Constructor
00111   clsfy_binary_tree(): root_(0),cache_node_(0) {}
00112 
00113   virtual ~clsfy_binary_tree();
00114 
00115   clsfy_binary_tree(const clsfy_binary_tree& srcTree);
00116 
00117   clsfy_binary_tree& operator=(const clsfy_binary_tree& srcTree);
00118 
00119   static void remove_tree(clsfy_binary_tree_node* root);
00120   //: Return the classification of the given probe vector.
00121   virtual unsigned classify(const vnl_vector<double> &input) const;
00122 
00123   //: Provides a probability-like value that the input being in each class.
00124   // output(i) i<nClasses, contains the probability that the input is in class i
00125   virtual void class_probabilities(vcl_vector<double> &outputs, const vnl_vector<double> &input) const;
00126 
00127   //: This value has properties of a Log likelihood of being in class (binary classifiers only)
00128   // class probability = exp(logL) / (1+exp(logL))
00129   virtual double log_l(const vnl_vector<double> &input) const;
00130 
00131   //: The number of possible output classes.
00132   virtual unsigned n_classes() const {return 1;}
00133 
00134   //: The dimensionality of input vectors.
00135   virtual unsigned n_dims() const;
00136 
00137   //: Storage version number
00138   virtual short version_no() const;
00139 
00140   //: Name of the class
00141   virtual vcl_string is_a() const;
00142 
00143   //: Name of the class
00144   virtual bool is_class(vcl_string const& s) const;
00145 
00146   //: Create a copy on the heap and return base class pointer
00147   virtual clsfy_classifier_base* clone() const;
00148 
00149   //: Print class to os
00150   virtual void print_summary(vcl_ostream& os) const;
00151 
00152   //: Save class to binary file stream
00153   virtual void b_write(vsl_b_ostream& bfs) const;
00154 
00155   //: Load class from binary file stream
00156   virtual void b_read(vsl_b_istream& bfs);
00157 
00158   //: Normally only the builder uses this
00159   void set_root(  clsfy_binary_tree_node* root);
00160  private:
00161   clsfy_binary_tree_node* root_;
00162   mutable clsfy_binary_tree_node* cache_node_;
00163  private:
00164   void copy(const clsfy_binary_tree& srcTree);
00165   void copy_children(clsfy_binary_tree_node* pSrcNode,clsfy_binary_tree_node* pNode);
00166 };
00167 
00168 #endif // clsfy_binary_tree_h_