00001 // This is mul/clsfy/clsfy_direct_boost.h 00002 #ifndef clsfy_direct_boost_h_ 00003 #define clsfy_direct_boost_h_ 00004 //: 00005 // \file 00006 // \brief Classifier using adaboost on combinations of simple 1D classifiers 00007 // \author Tim Cootes 00008 00009 #include <clsfy/clsfy_classifier_base.h> 00010 #include <clsfy/clsfy_classifier_1d.h> 00011 #include <vnl/vnl_vector.h> 00012 #include <vcl_iosfwd.h> 00013 00014 //: Classifier using adaboost on combinations of simple 1D classifiers 00015 // Uses a weighted combination of 1D classifiers applied to the 00016 // elements of the input vector. 00017 class clsfy_direct_boost : public clsfy_classifier_base 00018 { 00019 protected: 00020 00021 //: The classifiers in order 00022 vcl_vector<clsfy_classifier_1d*> classifier_1d_; 00023 00024 //: Coefficients applied to each classifier 00025 vcl_vector<double> wts_; 00026 00027 //: Index of input vector appropriate for each classifier 00028 vcl_vector<int> index_; 00029 00030 //: Thresholds given variable number of weak classifiers. 00031 // ie threshes[nc-1] is the threshold when using nc weak classifiers 00032 vcl_vector<double> threshes_; 00033 00034 //: number of classifiers used 00035 int n_clfrs_used_; 00036 00037 //: dimensionality of data. 00038 // (ie size of input vectors v, ie the total number of different features) 00039 int n_dims_; 00040 00041 //================protected methods ================================= 00042 00043 //: Delete objects on heap 00044 void delete_stuff(); 00045 00046 public: 00047 00048 //: Default constructor 00049 clsfy_direct_boost(); 00050 00051 //: Copy constructor 00052 clsfy_direct_boost(const clsfy_direct_boost&); 00053 00054 //: Copy operator 00055 clsfy_direct_boost& operator=(const clsfy_direct_boost&); 00056 00057 //: Destructor 00058 ~clsfy_direct_boost(); 00059 00060 //: Comparison 00061 bool operator==(const clsfy_direct_boost& x) const; 00062 00063 //: Clear all wts and classifiers 00064 void clear(); 00065 00066 //: Add classifier and alpha value 00067 void add_one_classifier(clsfy_classifier_1d* c1d, double wt, int index); 00068 00069 //: Set number of classifiers used (when applying strong classifier) 00070 void set_n_clfrs_used(unsigned int x) {if (x <= wts_.size()) n_clfrs_used_ = x;} 00071 00072 //: Access 00073 int n_clfrs_used() const {return n_clfrs_used_; } 00074 00075 //: Add one threshold 00076 void add_one_threshold(double thresh); 00077 00078 //: Add final threshold 00079 void add_final_threshold(double thresh); 00080 00081 //: Find the posterior probability of the input being in the positive class. 00082 // The result is outputs(0) 00083 virtual void class_probabilities(vcl_vector<double> &outputs, const vnl_vector<double> &input) const; 00084 00085 //: Classify the input vector. 00086 // Returns a number between 0 and nClasses-1 inclusive to represent the most likely class 00087 virtual unsigned classify(const vnl_vector<double> &input) const; 00088 00089 //: Log likelihood of being in the positive class. 00090 // Class probability = 1 / (1+exp(-log_l)) 00091 virtual double log_l(const vnl_vector<double> &input) const; 00092 00093 //: The dimensionality of input vectors. 00094 virtual unsigned n_dims() const { return n_dims_;} 00095 00096 //: Set number of classifiers used (when applying strong classifier) 00097 void set_n_dims(unsigned x) {n_dims_ = x;} 00098 00099 //: The number of possible output classes. 00100 // 1 indicates a binary classifier 00101 virtual unsigned n_classes() const { return 1;} 00102 00103 //: Set parameters. Clones taken of *classifier[i] 00104 void set_parameters(const vcl_vector<clsfy_classifier_1d*>& classifier, 00105 const vcl_vector<double>& threshes, 00106 const vcl_vector<double>& wts, 00107 const vcl_vector<int>& index); 00108 00109 //: Access functions 00110 const vcl_vector<clsfy_classifier_1d*>& classifiers() const 00111 {return classifier_1d_;} 00112 00113 const vcl_vector<double>& wts() const {return wts_;} 00114 00115 const vcl_vector<int>& index() const {return index_;} 00116 00117 const vcl_vector<double>& threshes() const 00118 {return threshes_;} 00119 00120 //: Version number for I/O 00121 short version_no() const; 00122 00123 //: Name of the class 00124 virtual vcl_string is_a() const; 00125 00126 //: Name of the class 00127 virtual bool is_class(vcl_string const& s) const; 00128 00129 //: Print class to os 00130 virtual void print_summary(vcl_ostream& os) const; 00131 00132 //: Save class to a binary File Stream 00133 virtual void b_write(vsl_b_ostream& bfs) const; 00134 00135 //: Create a deep copy. 00136 // Client is responsible for deleting returned object. 00137 virtual clsfy_classifier_base* clone() const 00138 { return new clsfy_direct_boost(*this); } 00139 00140 //: Load the class from a Binary File Stream 00141 virtual void b_read(vsl_b_istream& bfs); 00142 }; 00143 00144 #endif // clsfy_direct_boost_h_