contrib/mul/clsfy/clsfy_simple_adaboost.h
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_simple_adaboost.h
00002 #ifndef clsfy_simple_adaboost_h_
00003 #define clsfy_simple_adaboost_h_
00004 //:
00005 // \file
00006 // \brief Classifier using adaboost on combinations of simple 1D classifiers
00007 // \author Tim Cootes
00008 
00009 #include <clsfy/clsfy_classifier_base.h>
00010 #include <clsfy/clsfy_classifier_1d.h>
00011 #include <vnl/vnl_vector.h>
00012 #include <vcl_iosfwd.h>
00013 
00014 //: Classifier using adaboost on combinations of simple 1D classifiers
00015 //  Uses a weighted combination of 1D classifiers applied to the
00016 //  elements of the input vector.
00017 class clsfy_simple_adaboost : public clsfy_classifier_base
00018 {
00019  protected:
00020 
00021   //: The classifiers in order
00022   vcl_vector<clsfy_classifier_1d*> classifier_1d_;
00023 
00024   //: Coefficients applied to each classifier
00025   vcl_vector<double> alphas_;
00026 
00027   //: Index of input vector appropriate for each classifier
00028   vcl_vector<int> index_;
00029 
00030   //: number of classifiers used
00031   int n_clfrs_used_;
00032 
00033   //: dimensionality of data
00034   // (i.e. size of input vectors v, ie the total number of different features)
00035   int n_dims_;
00036 
00037 //================protected methods =================================
00038 
00039   //: Delete objects on heap
00040   void delete_stuff();
00041 
00042  public:
00043 
00044   //: Default constructor
00045   clsfy_simple_adaboost();
00046 
00047   //: Copy constructor
00048   clsfy_simple_adaboost(const clsfy_simple_adaboost&);
00049 
00050   //: Copy operator
00051   clsfy_simple_adaboost& operator=(const clsfy_simple_adaboost&);
00052 
00053   //: Destructor
00054   ~clsfy_simple_adaboost();
00055 
00056   //: Comparison
00057   bool operator==(const clsfy_simple_adaboost& x) const;
00058 
00059   //: Clear all alphas and classifiers
00060   void clear();
00061 
00062   //: Add classifier and alpha value
00063   void add_classifier(clsfy_classifier_1d* c1d, double alpha, int index);
00064 
00065   //: Set number of classifiers used (when applying strong classifier)
00066   void set_n_clfrs_used(unsigned int x) {if (x <= alphas_.size()) n_clfrs_used_ = x;}
00067 
00068   //: Access
00069   int n_clfrs_used() const {return n_clfrs_used_; }
00070 
00071   //: Find the posterior probability of the input being in the positive class.
00072   // The result is outputs(0)
00073   virtual void class_probabilities(vcl_vector<double> &outputs, const vnl_vector<double> &input) const;
00074 
00075   //: Classify the input vector.
00076   // Returns a number between 0 and nClasses-1 inclusive to represent the most likely class
00077   virtual unsigned classify(const vnl_vector<double> &input) const;
00078 
00079   //: Log likelihood of being in the positive class.
00080   // Class probability = 1 / (1+exp(-log_l))
00081   virtual double log_l(const vnl_vector<double> &input) const;
00082 
00083   //: The dimensionality of input vectors.
00084   virtual unsigned n_dims() const { return n_dims_;}
00085 
00086    //: Set number of classifiers used (when applying strong classifier)
00087   void set_n_dims(unsigned x) {n_dims_ = x;}
00088 
00089   //: The number of possible output classes.
00090   // 1 indicates a binary classifier
00091   virtual unsigned n_classes() const { return 1;}
00092 
00093   //: Set parameters.  Clones taken of *classifier[i]
00094   void set_parameters(const vcl_vector<clsfy_classifier_1d*>& classifier,
00095                       const vcl_vector<double>& alphas,
00096                       const vcl_vector<int>& index);
00097 
00098   //: Access functions
00099   const vcl_vector<clsfy_classifier_1d*>& classifiers() const
00100     {return classifier_1d_;}
00101 
00102   const vcl_vector<double>& alphas() const
00103     {return alphas_;}
00104 
00105   const vcl_vector<int>& index() const
00106     {return index_;}
00107 
00108   //: Version number for I/O
00109   short version_no() const;
00110 
00111   //: Name of the class
00112   virtual vcl_string is_a() const;
00113 
00114   //: Name of the class
00115   virtual bool is_class(vcl_string const& s) const;
00116 
00117   //: Print class to os
00118   virtual void print_summary(vcl_ostream& os) const;
00119 
00120   //: Save class to a binary File Stream
00121   virtual void b_write(vsl_b_ostream& bfs) const;
00122 
00123   //: Create a deep copy.
00124   // Client is responsible for deleting returned object.
00125   virtual clsfy_classifier_base* clone() const
00126   { return new clsfy_simple_adaboost(*this); }
00127 
00128   //: Load the class from a Binary File Stream
00129   virtual void b_read(vsl_b_istream& bfs);
00130 };
00131 
00132 #endif // clsfy_simple_adaboost_h_