contrib/mul/clsfy/clsfy_binary_tree_builder.h
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_tree_builder.h
00002 #ifndef clsfy_binary_tree_builder_h_
00003 #define clsfy_binary_tree_builder_h_
00004 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00005 #pragma interface
00006 #endif
00007 //:
00008 // \file
00009 // \brief Build a binary tree classifier
00010 // \author Martin Roberts
00011 
00012 #include <clsfy/clsfy_builder_base.h>
00013 #include <clsfy/clsfy_binary_tree.h>
00014 #include <vcl_vector.h>
00015 #include <vcl_set.h>
00016 #include <vcl_string.h>
00017 #include <vcl_iosfwd.h>
00018 #include <mbl/mbl_data_wrapper.h>
00019 #include <vnl/vnl_vector.h>
00020 #include <vnl/vnl_random.h>
00021 
00022 
00023 class clsfy_binary_tree_bnode :public  clsfy_binary_tree_node
00024 {
00025     // Builder-side counterpart of the classifier's tree node: the builder also
00026     // needs to keep track of the relevant data subsets at each node.
00027     vcl_set<unsigned> subIndicesL; // presumably sample indices for the left branch - confirm in .cxx
00028     vcl_set<unsigned> subIndicesR; // presumably sample indices for the right branch - confirm in .cxx
00029 
00030     // Construct from a parent node and a split op; both are forwarded to the base class
00031   clsfy_binary_tree_bnode(clsfy_binary_tree_node* parent,
00032                           const clsfy_binary_tree_op& op):
00033     clsfy_binary_tree_node(parent,op) {}
00034 
00035     virtual clsfy_binary_tree_node* create_child(const clsfy_binary_tree_op& op);
00036 
00037     // Note: the owning classifier removes the tree - beware, as once a node is
00038     // deleted its children may be inaccessible for deletion
00039     virtual ~clsfy_binary_tree_bnode();
00040     // No access specifier is given, so every member above is private; only the builder has access
00041     friend class clsfy_binary_tree_builder;
00042 };
00043 
00044 
00045 //: Builds clsfy_binary_tree classifiers
00046 // Keep finding the variable split that gives the least min_error for
00047 // a binary threshold. Divide up the dataset by that and keep recursively
00048 // building binary threshold classifiers in a tree structure till either
00049 // max depth level is reached, or a node is pure, or the node's data < min_node_size
00050 
00051 class clsfy_binary_tree_builder : public clsfy_builder_base
00052 {
00053     //: The max depth of any leaf node in the tree
00054     // If negative no max is applied, and all final leaf nodes are pure
00055     // (i.e. single class)
00056     int max_depth_;
00057 
00058     //: Minimum number of points associated with any node
00059     // If negative this is ignored, otherwise if a split would produce a child
00060     // node smaller than this, then the split does not occur and the branch is
00061     // terminated
00062     int min_node_size_;
00063 
00064     //: Set this for random forest behaviour
00065     // At each split the selection is only from a random subset of this size
00066     // If negative (default) it is ignored and all are used
00067     int nbranch_params_;
00068 
00069     //: Work space for randomising params (NB not thread safe)
00070     mutable vcl_vector<unsigned > base_indices_;
00071 
00072   public:
00073     //: Default constructor
00074     clsfy_binary_tree_builder();
00075 
00076     //: Create empty model
00077     // Caller is responsible for deletion
00078     virtual clsfy_classifier_base* new_classifier() const;
00079 
00080     //: Build classifier from data
00081     // Returns the mean error over the training set (see also set_calc_test_error).
00082     virtual double build(clsfy_classifier_base& classifier,
00083                          mbl_data_wrapper<vnl_vector<double> >& inputs,
00084                          unsigned nClasses,
00085                          const vcl_vector<unsigned> &outputs) const;
00086 
00087     //: Name of the class
00088     virtual vcl_string is_a() const;
00089 
00090     //: Test whether s names this class (base-class handling: see implementation)
00091     virtual bool is_class(vcl_string const& s) const;
00092 
00093     //: IO Version number
00094     short version_no() const;
00095 
00096     //: Create a copy on the heap and return base class pointer
00097     virtual clsfy_builder_base* clone() const;
00098 
00099     //: Print class to os
00100     virtual void print_summary(vcl_ostream& os) const;
00101 
00102     //: Save class to binary file stream
00103     virtual void b_write(vsl_b_ostream& bfs) const;
00104 
00105     //: Load class from binary file stream
00106     virtual void b_read(vsl_b_istream& bfs);
00107 
00108     //: The max tree depth (default -1 means no max set)
00109     int max_depth() const {return max_depth_;}
00110 
00111     //: Set the max depth of any leaf node in the tree.
00112     // If negative the value is ignored (no max is applied) and building
00113     // continues till all final leaf nodes are pure (i.e. single class).
00114     // NOTE(review): default is documented as -1 on max_depth() - confirm in ctor
00115     void set_max_depth(int max_depth) {max_depth_=max_depth;}
00116     //: Minimum number of points associated with any node (negative means ignored)
00117     int min_node_size() const {return min_node_size_;}
00118 
00119     //: Set minimum number of points associated with any node
00120     // If negative this is ignored, otherwise if a split would produce a child
00121     // node smaller than this, then the split does not occur and the branch is
00122     // terminated
00123     void set_min_node_size(int min_node_size) {min_node_size_=min_node_size;}
00124 
00125     //: Set this for random forest behaviour
00126     // At each split the selection is only from a random subset of this size
00127     // If negative then it is ignored
00128     void set_nbranch_params(int nbranch_params) {nbranch_params_ = nbranch_params;}
00129 
00130     //: Set whether the build calculates a test error over the input training set
00131     // Default is on, but this can be turned off e.g. for a random forest of
00132     // many child trees
00133     void set_calc_test_error(bool on) {calc_test_error_=on;}
00134 
00135     //: Seed the sampler used to select branching parameter subsets
00136     void seed_sampler(unsigned long seed);
00137   protected:
00138     //: Randomly select the ndimsUsed dimensions for current branch
00139     // Return indices of selected parameters in param_indices
00140     // Best of these is then chosen as the branch
00141     virtual void randomise_parameters(unsigned ndimsUsed,
00142                                       vcl_vector<unsigned  >& param_indices) const;
00143     //: Random number generator; declared mutable so it is usable from const member functions
00144     mutable  vnl_random random_sampler_;
00145 
00146 
00147   private:
00148     void build_children(
00149         const vcl_vector<vnl_vector<double> >& vin,
00150         const vcl_vector<unsigned>& outputs,
00151         clsfy_binary_tree_bnode* parent, bool to_left) const;
00152 
00153     void copy_children(clsfy_binary_tree_bnode* pBuilderNode,clsfy_binary_tree_node* pNode) const;
00154 
00155     void set_node_prob(clsfy_binary_tree_node* pNode,
00156                        clsfy_binary_tree_bnode* pBuilderNode) const ;
00157 
00158     void build_a_node(
00159         const vcl_vector<vnl_vector<double> >& vin,
00160         const vcl_vector<unsigned>& outputs,
00161         const vcl_set<unsigned >& subIndices,
00162         clsfy_binary_tree_bnode* pNode) const;
00163     // presumably true when all samples indexed by subIndices share one class in outputs - confirm in .cxx
00164     bool isNodePure(const vcl_set<unsigned >& subIndices,
00165                     const vcl_vector<unsigned>& outputs) const;
00166 
00167     void add_terminator(
00168         const vcl_vector<vnl_vector<double> >& vin,
00169         const vcl_vector<unsigned>& outputs,
00170         clsfy_binary_tree_bnode* parent,
00171         bool to_left, bool pure) const;
00172     //: Whether build calculates a test error over the training set (written by set_calc_test_error)
00173     bool calc_test_error_;
00174 };
00175 
00176 
00177 #endif // clsfy_binary_tree_builder_h_