contrib/mul/clsfy/clsfy_direct_boost.h
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_direct_boost.h
00002 #ifndef clsfy_direct_boost_h_
00003 #define clsfy_direct_boost_h_
00004 //:
00005 // \file
00006 // \brief Classifier using adaboost on combinations of simple 1D classifiers
00007 // \author Tim Cootes
00008 
00009 #include <clsfy/clsfy_classifier_base.h>
00010 #include <clsfy/clsfy_classifier_1d.h>
00011 #include <vnl/vnl_vector.h>
00012 #include <vcl_iosfwd.h>
00013 
00014 //: Classifier using adaboost on combinations of simple 1D classifiers
00015 //  Uses a weighted combination of 1D classifiers applied to the
00016 //  elements of the input vector.
00017 class clsfy_direct_boost : public clsfy_classifier_base
00018 {
00019  protected:
00020 
00021   //: The 1D classifiers, in order (stored as pointers)
00022   vcl_vector<clsfy_classifier_1d*> classifier_1d_;
00023 
00024   //: Coefficients (weights) applied to each classifier
00025   vcl_vector<double> wts_;
00026 
00027   //: Index of input vector appropriate for each classifier
00028   vcl_vector<int> index_;
00029 
00030   //: Thresholds given variable number of weak classifiers.
00031   // i.e. threshes_[nc-1] is the threshold when using nc weak classifiers
00032   vcl_vector<double> threshes_;
00033 
00034   //: Number of classifiers used (changed via set_n_clfrs_used())
00035   int n_clfrs_used_;
00036 
00037   //: Dimensionality of data.
00038   // (i.e. size of input vectors v, i.e. the total number of different features)
00039   int n_dims_;
00040 
00041 //================protected methods =================================
00042 
00043   //: Delete objects on heap
00044   void delete_stuff();
00045 
00046  public:
00047 
00048   //: Default constructor
00049   clsfy_direct_boost();
00050 
00051   //: Copy constructor
00052   clsfy_direct_boost(const clsfy_direct_boost&);
00053 
00054   //: Copy operator
00055   clsfy_direct_boost& operator=(const clsfy_direct_boost&);
00056 
00057   //: Destructor
00058   ~clsfy_direct_boost();
00059 
00060   //: Comparison
00061   bool operator==(const clsfy_direct_boost& x) const;
00062 
00063   //: Clear all wts and classifiers
00064   void clear();
00065 
00066   //: Add one classifier c1d with its weight wt (the alpha value) and index
00067   void add_one_classifier(clsfy_classifier_1d* c1d, double wt, int index);
00068 
00069   //: Set number of classifiers used (when applying strong classifier); the call is ignored if x > wts_.size()
00070   void set_n_clfrs_used(unsigned int x) {if (x <= wts_.size()) n_clfrs_used_ = x;}
00071 
00072   //: Number of classifiers used (when applying strong classifier)
00073   int n_clfrs_used() const {return n_clfrs_used_; }
00074 
00075   //: Add one threshold
00076   void add_one_threshold(double thresh);
00077 
00078   //: Add final threshold
00079   void add_final_threshold(double thresh);
00080 
00081   //: Find the posterior probability of the input being in the positive class.
00082   // The result is outputs[0]
00083   virtual void class_probabilities(vcl_vector<double> &outputs, const vnl_vector<double> &input) const;
00084 
00085   //: Classify the input vector.
00086   // Returns a number between 0 and nClasses-1 inclusive to represent the most likely class
00087   virtual unsigned classify(const vnl_vector<double> &input) const;
00088 
00089   //: Log likelihood of being in the positive class.
00090   // Class probability = 1 / (1+exp(-log_l))
00091   virtual double log_l(const vnl_vector<double> &input) const;
00092 
00093   //: The dimensionality of input vectors.
00094   virtual unsigned n_dims() const { return n_dims_;}
00095 
00096   //: Set the dimensionality of input vectors.
00097   void set_n_dims(unsigned x) {n_dims_ = x;}
00098 
00099   //: The number of possible output classes.
00100   // 1 indicates a binary classifier
00101   virtual unsigned n_classes() const { return 1;}
00102 
00103   //: Set parameters.  Clones taken of *classifier[i]
00104   void set_parameters(const vcl_vector<clsfy_classifier_1d*>& classifier,
00105                       const vcl_vector<double>& threshes,
00106                       const vcl_vector<double>& wts,
00107                       const vcl_vector<int>& index);
00108 
00109   //: Read-only access functions (classifiers, weights, indices, thresholds)
00110   const vcl_vector<clsfy_classifier_1d*>& classifiers() const
00111     {return classifier_1d_;}
00112 
00113   const vcl_vector<double>& wts() const {return wts_;}
00114 
00115   const vcl_vector<int>& index() const {return index_;}
00116 
00117   const vcl_vector<double>& threshes() const
00118     {return threshes_;}
00119 
00120   //: Version number for I/O
00121   short version_no() const;
00122 
00123   //: Name of the class
00124   virtual vcl_string is_a() const;
00125 
00126   //: Does the name of the class match the argument?
00127   virtual bool is_class(vcl_string const& s) const;
00128 
00129   //: Print class to os
00130   virtual void print_summary(vcl_ostream& os) const;
00131 
00132   //: Save class to a binary File Stream
00133   virtual void b_write(vsl_b_ostream& bfs) const;
00134 
00135   //: Create a deep copy.
00136   // Client is responsible for deleting returned object.
00137   virtual clsfy_classifier_base* clone() const
00138   { return new clsfy_direct_boost(*this); }
00139 
00140   //: Load the class from a Binary File Stream
00141   virtual void b_read(vsl_b_istream& bfs);
00142 };
00143 
00144 #endif // clsfy_direct_boost_h_