contrib/mul/clsfy/clsfy_random_forest_builder.h
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_random_forest_builder.h
00002 #ifndef clsfy_random_forest_builder_h_
00003 #define clsfy_random_forest_builder_h_
00004 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00005 #pragma interface
00006 #endif
00007 //:
00008 // \file
00009 // \brief Build a random forest classifier
00010 // \author Martin Roberts
00011 
00012 #include <clsfy/clsfy_builder_base.h>
00013 #include <clsfy/clsfy_random_forest.h>
00014 #include <vcl_vector.h>
00015 #include <vcl_set.h>
00016 #include <vcl_string.h>
00017 #include <vcl_iosfwd.h>
00018 #include <vnl/vnl_vector.h>
00019 #include <vnl/vnl_random.h>
00020 
00021 #include <mbl/mbl_data_wrapper.h>
00022 
00023 
00024 //: Builds clsfy_random_forest classifiers
00025 class clsfy_random_forest_builder : public clsfy_builder_base
00026 {
00027  public:
00028   //: Default constructor
00029   clsfy_random_forest_builder();
00030   //: Construct giving the number of trees and, optionally, max depth and min node size (both default to -1)
00031   clsfy_random_forest_builder(unsigned ntrees,
00032                               int max_depth=-1,int min_node_size=-1);
00033   virtual ~clsfy_random_forest_builder();
00034 
00035   //: Create empty model
00036   // Caller is responsible for deletion
00037   virtual clsfy_classifier_base* new_classifier() const;
00038 
00039   //: Build classifier from data
00040   // Returns the mean error over the training set (see also set_calc_test_error).
00041   virtual double build(clsfy_classifier_base& classifier,
00042                        mbl_data_wrapper<vnl_vector<double> >& inputs,
00043                        unsigned nClasses,
00044                        const vcl_vector<unsigned> &outputs) const;
00045 
00046   //: Name of the class
00047   virtual vcl_string is_a() const;
00048 
00049   //: Does the name of the class match the argument s? (NOTE(review): implementation not in this header)
00050   virtual bool is_class(vcl_string const& s) const;
00051 
00052   //: IO Version number
00053   short version_no() const;
00054 
00055   //: Create a copy on the heap and return base class pointer
00056   virtual clsfy_builder_base* clone() const;
00057 
00058   //: Print class to os
00059   virtual void print_summary(vcl_ostream& os) const;
00060 
00061   //: Save class to binary file stream
00062   virtual void b_write(vsl_b_ostream& bfs) const;
00063 
00064   //: Load class from binary file stream
00065   virtual void b_read(vsl_b_istream& bfs);
00066 
00067   //: The max tree depth (default -1 means no max set)
00068   int max_depth() const {return max_depth_;}
00069 
00070   //: Set the max tree depth.
00071   // If no max is set, tree building is meant to continue until
00072   // all final leaf nodes are pure (i.e. single class)
00073   // If set negative the value is ignored (no max is applied)
00074   void set_max_depth(int max_depth) {max_depth_=max_depth;}
00075   //: Minimum number of points associated with any node (negative means ignored)
00076   int min_node_size() const {return min_node_size_;}
00077 
00078   //: Set minimum number of points associated with any node
00079   // If negative this is ignored, otherwise if a split would produce a child
00080   // node less than this, then the split does not occur and the branch is
00081   // terminated
00082   void set_min_node_size(int min_node_size) {min_node_size_=min_node_size;}
00083 
00084 
00085   //: Set number of trees in forest
00086   // Note this must be set before calling build
00087   // Default is 100
00088   void set_ntrees(unsigned ntrees) {ntrees_=ntrees;}
00089   //: Number of trees in forest
00090   unsigned ntrees() const {return ntrees_;}
00091   //: Seed the sampler (presumably random_sampler_; NOTE(review): confirm in the implementation file)
00092   virtual void seed_sampler(unsigned long seed);
00093 
00094   //: Set whether the build calculates a test error over the input training set
00095   // Default is on, but this can be turned off
00096   // e.g. for a parallel build of many partial random forests
00097   // which can later be merged
00098   void set_calc_test_error(bool on) {calc_test_error_=on;}
00099 
00100   //: Save a pointer to storage for out of bag indices
00101   void set_oob_indices( vcl_vector<vcl_vector<unsigned > >* poobIndices)
00102   {poob_indices_=poobIndices;}
00103 
00104  protected:
00105   //: Pick the number of parameters that the tree builder branches on
00106   // Default uses sqrt of ndims
00107   virtual unsigned select_nbranch_params(unsigned ndims) const;
00108 
00109   //: Pick a random data subset (with replacement)
00110   virtual void select_data(vcl_vector<vnl_vector<double> >& inputs,
00111                            const vcl_vector<unsigned> &outputs,
00112                            vcl_vector<vnl_vector<double> >& bootstrapped_inputs,
00113                            vcl_vector<unsigned> & bootstrapped_outputs) const;
00114   //: Get a seed for a tree builder (NOTE(review): how the seed is produced is not visible in this header)
00115   virtual unsigned long get_tree_builder_seed() const;
00116 
00117   //: Number of trees
00118   unsigned ntrees_;
00119   //: The max depth of any child tree
00120   // If negative no max is applied, and all final leaf nodes are pure
00121   // (i.e. single class)
00122   int max_depth_;
00123 
00124 
00125   //: Minimum number of points associated with any node
00126   // If negative this is ignored, otherwise if a split would produce a child
00127   // node less than this, then the split does not occur and the branch is
00128   // terminated
00129   int min_node_size_;
00130 
00131   //: Uniform sampler on 0,1 (for bootstrapping); declared mutable so const member functions can use it
00132   mutable  vnl_random random_sampler_;
00133 
00134   //: Pointer to storage of point indices for each bootstrapped tree
00135   // Can be used for out of bag estimates
00136   // Saves for tree i the indices of all points used in its training
00137   // Note the storage is supplied from outside this class, as this is a kind of bolt-on
00138   vcl_vector<vcl_vector<unsigned > >* poob_indices_;
00139  private:
00140   //: Does the builder calculate the error on the training set?
00141   bool calc_test_error_;
00142 };
00143 
00144 
00145 #endif // clsfy_random_forest_builder_h_