00001 // This is mul/clsfy/clsfy_random_forest_builder.h 00002 #ifndef clsfy_random_forest_builder_h_ 00003 #define clsfy_random_forest_builder_h_ 00004 #ifdef VCL_NEEDS_PRAGMA_INTERFACE 00005 #pragma interface 00006 #endif 00007 //: 00008 // \file 00009 // \brief Build a random forest classifier 00010 // \author Martin Roberts 00011 00012 #include <clsfy/clsfy_builder_base.h> 00013 #include <clsfy/clsfy_random_forest.h> 00014 #include <vcl_vector.h> 00015 #include <vcl_set.h> 00016 #include <vcl_string.h> 00017 #include <vcl_iosfwd.h> 00018 #include <vnl/vnl_vector.h> 00019 #include <vnl/vnl_random.h> 00020 00021 #include <mbl/mbl_data_wrapper.h> 00022 00023 00024 //: Builds clsfy_random_forest classifiers 00025 class clsfy_random_forest_builder : public clsfy_builder_base 00026 { 00027 public: 00028 // Dflt ctor 00029 clsfy_random_forest_builder(); 00030 00031 clsfy_random_forest_builder(unsigned ntrees, 00032 int max_depth=-1,int min_node_size=-1); 00033 virtual ~clsfy_random_forest_builder(); 00034 00035 //: Create empty model 00036 // Caller is responsible for deletion 00037 virtual clsfy_classifier_base* new_classifier() const; 00038 00039 //: Build classifier from data 00040 // return the mean error over the training set. 00041 virtual double build(clsfy_classifier_base& classifier, 00042 mbl_data_wrapper<vnl_vector<double> >& inputs, 00043 unsigned nClasses, 00044 const vcl_vector<unsigned> &outputs) const; 00045 00046 //: Name of the class 00047 virtual vcl_string is_a() const; 00048 00049 //: Name of the class 00050 virtual bool is_class(vcl_string const& s) const; 00051 00052 //: IO Version number 00053 short version_no() const; 00054 00055 //: Create a copy on the heap and return base class pointer 00056 virtual clsfy_builder_base* clone() const; 00057 00058 //: Print class to os 00059 virtual void print_summary(vcl_ostream& os) const; 00060 00061 //: Save class to binary file stream 00062 virtual void b_write(vsl_b_ostream& bfs) const; 00063 00064 //: Load class from binary file stream 00065 virtual void b_read(vsl_b_istream& bfs); 00066 00067 //: The max tree depth (default -1 means no max set ) 00068 int max_depth() const {return max_depth_;} 00069 00070 //: Set the number of nearest neighbours to look for. 00071 // If not see default is high value to force continuation till 00072 // all final leaf nodes are pure (i.e. single class) 00073 // If set negative the value is ignored 00074 void set_max_depth(int max_depth) {max_depth_=max_depth;} 00075 00076 int min_node_size() const {return min_node_size_;} 00077 00078 //: Set minimum number of points associated with any node 00079 // If negative this is ignored, otherwise if a split would produce a child 00080 // node less than this, then the split does not occur and the branch is 00081 // terminated 00082 void set_min_node_size(int min_node_size) {min_node_size_=min_node_size;} 00083 00084 00085 //: set number of trees in forest 00086 // Note this must be set before calling build 00087 // Default is 100 00088 void set_ntrees(unsigned ntrees) {ntrees_=ntrees;} 00089 00090 unsigned ntrees() const {return ntrees_;} 00091 00092 virtual void seed_sampler(unsigned long seed); 00093 00094 //: set whether the build calculates a test error over the input training set 00095 // Default is on, but this can be turned off 00096 // e.g. for a parallel build of many partial random forests of 00097 // which can be later merged 00098 void set_calc_test_error(bool on) {calc_test_error_=on;} 00099 00100 //: Save a pointer to storage for out of bag indices 00101 void set_oob_indices( vcl_vector<vcl_vector<unsigned > >* poobIndices) 00102 {poob_indices_=poobIndices;} 00103 00104 protected: 00105 //: Pick the number of parameters that the tree builder branches on 00106 // Default uses sqrt of ndims 00107 virtual unsigned select_nbranch_params(unsigned ndims) const; 00108 00109 //: Pick a random data subset (with replacement) 00110 virtual void select_data(vcl_vector<vnl_vector<double> >& inputs, 00111 const vcl_vector<unsigned> &outputs, 00112 vcl_vector<vnl_vector<double> >& bootstrapped_inputs, 00113 vcl_vector<unsigned> & bootstrapped_outputs) const; 00114 00115 virtual unsigned long get_tree_builder_seed() const; 00116 00117 //: Number of trees 00118 unsigned ntrees_; 00119 //: The max depth of any child tree 00120 //If negative no max is applied, and all final leaf nodes are pure 00121 //(i.e. single class) 00122 int max_depth_; 00123 00124 00125 //: Minimum number of points associated with any node 00126 // If negative this is ignored, otherwise if a split would produce a child 00127 // node less than this, then the split does not occur and the branch is 00128 // terminated 00129 int min_node_size_; 00130 00131 //: Uniform sampler on 0,1 (for bootstrapping) 00132 mutable vnl_random random_sampler_; 00133 00134 //: Pointer to storage of point indices for each bootstrapped tree 00135 // Can be used for out of bag estimates 00136 // Saves for tree i the indices of all points used in its training 00137 // Note the storage is supplied from outside this class, as this is a kind of bolt-on 00138 vcl_vector<vcl_vector<unsigned > >* poob_indices_; 00139 private: 00140 //: Does the builder calculate the error on the training set? 00141 bool calc_test_error_; 00142 }; 00143 00144 00145 #endif // clsfy_random_forest_builder_h_