contrib/mul/clsfy/clsfy_binary_hyperplane_gmrho_builder.h
Go to the documentation of this file.
00001 #ifndef clsfy_binary_hyperplane_gmrho_builder_h
00002 #define clsfy_binary_hyperplane_gmrho_builder_h
00003 //:
00004 // \file
00005 // \author Martin Roberts
00006 // \brief Builder for linear 2-state classifier, using a sigmoidal Geman-McClure rho function
00007 
00008 #include <vcl_string.h>
00009 #include <vcl_iosfwd.h> // for std::ostream
00010 // not used? #include <vcl_functional.h>
00011 #include <vnl/io/vnl_io_vector.h>
00012 #include <vnl/io/vnl_io_matrix.h>
00013 #include <vnl/vnl_matrix.h>
00014 #include <vnl/vnl_vector.h>
00015 #include <clsfy/clsfy_binary_hyperplane_ls_builder.h>
00016 
00017 //=======================================================================
00018 
00019 
00020 //: Builder for linear 2-state classifier
00021 //  Uses a Geman-McClure robust function, rather than a least squares fit, on points that
00022 //  are not mis-classified.
00023 //  This increases the weighting given to points near the boundary in determining its fit.
00024 //  A conventional least squares fit is performed first to determine a starting solution,
00025 //  and also the sigma scaling factor used in the GM function.
00026 //  Several iterations are performed during which sigma is reduced
00027 //  (i.e. deterministic annealing), to try and avoid local minima.
00028 
00029 class clsfy_binary_hyperplane_gmrho_builder  : public clsfy_binary_hyperplane_ls_builder
00030 {
00031  private:
00032   //: The classifier weights (weight N is the constant)
00033   mutable vnl_vector<double> weights_;
00034 
00035   //: Number of training examples (data.rows())
00036   mutable unsigned num_examples_;
00037   //: Number of variables (data.cols())
00038   mutable unsigned num_vars_;
00039   //: Tolerance for non-linear optimiser convergence (the default ctor sets 1.0E-8)
00040   mutable double epsilon_;
00041 
00042   //: Should sigma be estimated during the build (true), or a pre-defined value used (false)?
00043   bool auto_estimate_sigma_;
00044   //: Use this for sigma if auto_estimate_sigma_ is false
00045   double sigma_preset_;
00046 
00047   //: Estimate the scale (sigma) used in the Geman-McClure function
00048   // This is increased by the mis-classification overlap region, if any
00049   double estimate_sigma(const vnl_matrix<double>& data,
00050                         const vnl_vector<double>& y) const;
00051   //: Determine the weights for the hyperplane, for the given sigma
00052   void determine_weights(const vnl_matrix<double>& data,
00053                          const vnl_vector<double>& y,
00054                          double sigma) const;
00055  public:
00056 
00057   //: Default ctor: zero sizes, epsilon 1.0E-8, sigma auto-estimated (preset value 1.0)
00058   clsfy_binary_hyperplane_gmrho_builder():
00059       clsfy_binary_hyperplane_ls_builder(),
00060       num_examples_(0),num_vars_(0),epsilon_(1.0E-8),
00061       auto_estimate_sigma_(true),sigma_preset_(1.0) {}
00062 
00063   //: Build a linear classifier, with the given data.
00064   // Return the mean error over the training set.
00065   double build(clsfy_classifier_base &classifier,
00066                mbl_data_wrapper<vnl_vector<double> > &inputs,
00067                const vcl_vector<unsigned> &outputs) const;
00068 
00069   //: Build model from data
00070   // Return the mean error over the training set.
00071   // For this classifier, you must set nClasses==1 to
00072   // indicate a binary classifier
00073   virtual double build(clsfy_classifier_base& model,
00074                        mbl_data_wrapper<vnl_vector<double> >& inputs,
00075                        unsigned nClasses,
00076                        const vcl_vector<unsigned> &outputs) const;
00077 
00078   //: Version number for I/O
00079   short version_no() const;
00080 
00081   //: Name of the class
00082   vcl_string is_a() const;
00083 
00084   //: Is s a name of this class? (the exact matching rule is in the implementation, not shown here)
00085   virtual bool is_class(vcl_string const& s) const;
00086 
00087   //: Print class to os
00088   void print_summary(vcl_ostream& os) const;
00089 
00090   //: Create a deep copy.
00091   // client is responsible for deleting returned object.
00092   virtual clsfy_builder_base* clone() const;
00093 
00094   //: Should sigma be estimated during the build (true), or a pre-defined value used (false)?
00095   void set_auto_estimate_sigma(bool bAuto) {auto_estimate_sigma_ = bAuto;}
00096   //: Use this for sigma if auto_estimate_sigma_ is false
00097   void set_sigma_preset(double sigma_preset) {sigma_preset_ = sigma_preset;}
00098   // Binary I/O on vsl streams (presumably save/load of this builder; format is set by the implementation)
00099   virtual void b_write(vsl_b_ostream &) const;
00100   virtual void b_read(vsl_b_istream &);
00101 };
00102 
00103 #endif // clsfy_binary_hyperplane_gmrho_builder_h