contrib/mul/clsfy/clsfy_binary_hyperplane_logit_builder.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_hyperplane_logit_builder.cxx
00002 #include "clsfy_binary_hyperplane_logit_builder.h"
00003 //:
00004 // \file
00005 // \brief Linear classifier builder using a logit loss function
00006 // \author Tim Cootes
00007 // \date 18 Jul 2009
00008 
00009 //=======================================================================
00010 
00011 #include <vcl_string.h>
00012 #include <vcl_iostream.h>
00013 #include <vcl_vector.h>
00014 #include <vcl_cassert.h>
00015 #include <vcl_cmath.h>
00016 #include <vcl_algorithm.h>
00017 #include <vcl_numeric.h>
00018 #include <vnl/vnl_vector_ref.h>
00019 #include <vnl/algo/vnl_lbfgs.h>
00020 #include <clsfy/clsfy_logit_loss_function.h>
00021 
00022 clsfy_binary_hyperplane_logit_builder::clsfy_binary_hyperplane_logit_builder():
00023       clsfy_binary_hyperplane_ls_builder(),
00024       alpha_(1e-6),min_p_(0.001) {}
00025 
00026 //: Weighting on regularisation term
00027 void clsfy_binary_hyperplane_logit_builder::set_alpha(double a)
00028 {
00029   alpha_=a;
00030 }
00031 
00032 //: Min prob to be returned by classifier
00033 void clsfy_binary_hyperplane_logit_builder::set_min_p(double p)
00034 {
00035   min_p_=p;
00036 }
00037 
00038 //---------------------------------------------------------------------
00039 //------- The builder member functions --------------------------------
00040 //---------------------------------------------------------------------
00041 
00042 //: Build a linear classifier, with the given data.
00043 // Return the mean error over the training set.
00044 // n_classes must be 1.
00045 double clsfy_binary_hyperplane_logit_builder::build(clsfy_classifier_base& classifier,
00046                                           mbl_data_wrapper<vnl_vector<double> >& inputs,
00047                                           unsigned n_classes,
00048                                           const vcl_vector<unsigned>& outputs) const
00049 {
00050   assert (n_classes == 1);
00051   return clsfy_binary_hyperplane_logit_builder::build(classifier, inputs, outputs);
00052 }
00053 
00054 //: Build a linear hyperplane classifier with the given data.
00055 double clsfy_binary_hyperplane_logit_builder::build(clsfy_classifier_base& classifier,
00056                                         mbl_data_wrapper<vnl_vector<double> >& inputs,
00057                                         const vcl_vector<unsigned>& outputs) const
00058 {
00059   // First let the base class get us a starting solution
00060   clsfy_binary_hyperplane_ls_builder::build( classifier,inputs,outputs);
00061 
00062   vcl_cout<<"Initial error:"<< clsfy_test_error(classifier, inputs, outputs) <<vcl_endl;
00063 
00064   unsigned n_egs = inputs.size();
00065   if (n_egs == 0)
00066   {
00067       vcl_cerr<<"WARNING - clsfy_binary_hyperplane_logit_builder::build called with no data\n";
00068       return 0.0;
00069   }
00070 
00071   assert(classifier.is_a()=="clsfy_binary_hyperplane");
00072   clsfy_binary_hyperplane& plane = static_cast<clsfy_binary_hyperplane&>(classifier);
00073 
00074   // Set initial weights using initial LS hyperplane
00075   unsigned n_dim = plane.n_dims();
00076   vnl_vector<double> w(n_dim+1);
00077   w[0]=-1*plane.bias();
00078   for (unsigned i=0;i<n_dim;++i) w[1+i]=plane.weights()[i];
00079 
00080   // Set up cost function
00081   vnl_vector<double> c(n_egs);
00082   for (unsigned i=0;i<n_egs;++i) c[i]=2.0*(outputs[i]-0.5);  // =+/-1
00083   clsfy_quad_regulariser quad_reg(alpha_);
00084   clsfy_logit_loss_function cost_fn(inputs,c,min_p_,&quad_reg);
00085 
00086   // Minimise it
00087   vnl_lbfgs optimizer(cost_fn);
00088   optimizer.set_verbose(true);
00089   optimizer.set_f_tolerance(1e-7);
00090   optimizer.set_x_tolerance(1e-5);
00091   if (!optimizer.minimize(w))
00092   {
00093     vcl_cerr<<"vnl_lbfgs optimisation failed!"<<vcl_endl;
00094     vcl_cerr<<"Failure code: "<<optimizer.get_failure_code()<<vcl_endl;
00095   }
00096 
00097   vnl_vector_ref<double> new_wts(n_dim,&w[1]);
00098 
00099   plane.set(new_wts,-w[0]);
00100 
00101   return clsfy_test_error(classifier, inputs, outputs);
00102 }
00103 
00104 //=======================================================================
00105 
00106 void clsfy_binary_hyperplane_logit_builder::b_write(vsl_b_ostream &bfs) const
00107 {
00108   const int version_no=1;
00109   vsl_b_write(bfs, version_no);
00110   clsfy_binary_hyperplane_ls_builder::b_write(bfs);
00111   vsl_b_write(bfs,min_p_);
00112   vsl_b_write(bfs,alpha_);
00113 }
00114 
00115 //=======================================================================
00116 
00117 void clsfy_binary_hyperplane_logit_builder::b_read(vsl_b_istream &bfs)
00118 {
00119   if (!bfs) return;
00120 
00121   short version;
00122   vsl_b_read(bfs,version);
00123   switch (version)
00124   {
00125     case (1):
00126       clsfy_binary_hyperplane_ls_builder::b_read(bfs);
00127       vsl_b_read(bfs,min_p_);
00128       vsl_b_read(bfs,alpha_);
00129       break;
00130     default:
00131       vcl_cerr << "I/O ERROR: clsfy_binary_hyperplane_logit_builder::b_read(vsl_b_istream&)\n"
00132                 << "           Unknown version number "<< version << '\n';
00133       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00134   }
00135 }
00136 
00137 //=======================================================================
00138 
00139 vcl_string clsfy_binary_hyperplane_logit_builder::is_a() const
00140 {
00141   return vcl_string("clsfy_binary_hyperplane_logit_builder");
00142 }
00143 
00144 //=======================================================================
00145 
00146 bool clsfy_binary_hyperplane_logit_builder::is_class(vcl_string const& s) const
00147 {
00148   return s == clsfy_binary_hyperplane_logit_builder::is_a() || clsfy_binary_hyperplane_ls_builder::is_class(s);
00149 }
00150 
00151 //=======================================================================
00152 
00153 short clsfy_binary_hyperplane_logit_builder::version_no() const
00154 {
00155   return 1;
00156 }
00157 
00158 //=======================================================================
00159 
00160 void clsfy_binary_hyperplane_logit_builder::print_summary(vcl_ostream& os) const
00161 {
00162   os << is_a();
00163 }
00164 
00165 //=======================================================================
00166 clsfy_builder_base* clsfy_binary_hyperplane_logit_builder::clone() const
00167 {
00168   return new clsfy_binary_hyperplane_logit_builder(*this);
00169 }
00170