contrib/mul/clsfy/clsfy_binary_hyperplane_gmrho_builder.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_hyperplane_gmrho_builder.cxx
00002 #include "clsfy_binary_hyperplane_gmrho_builder.h"
00003 //:
00004 // \file
00005 // \brief Implement a two-class output linear classifier builder using a Geman-McClure robust error function
00006 // \author Martin Roberts
00007 // \date 4 Nov 2006
00008 
00009 //=======================================================================
00010 
00011 #include <vcl_string.h>
00012 #include <vcl_iostream.h>
00013 #include <vcl_vector.h>
00014 #include <vcl_cassert.h>
00015 #include <vcl_cmath.h>
00016 #include <vcl_algorithm.h>
00017 #include <vcl_numeric.h>
00018 #include <vcl_cstddef.h>
00019 #include <vnl/vnl_vector_ref.h>
00020 #include <vnl/algo/vnl_lbfgs.h>
00021 
00022 
//: Helper classes for the builder: the cost function to be minimised and the small functors it relies on
00024 namespace clsfy_binary_hyperplane_gmrho_builder_helpers
00025 {
00026     //: The cost function, sum Geman-McClure error functions over all training examples
00027     class gmrho_sum : public vnl_cost_function
00028     {
00029         //: Reference to data matrix, one row per training example
00030         const vnl_matrix<double>& x_;
00031         //: Reference to required outputs
00032         const vnl_vector<double>& y_;
00033         //: Scale factor used in Geman-McClure error function
00034         double sigma_;
00035         //: sigma squared
00036         double var_;
00037         //: Number of training examples (x_.rows())
00038         unsigned num_examples_;
00039         //: Number of dimensions (x_.cols())
00040         unsigned num_vars_;
00041         //: var_/(1+var_)^2 - ensures continuity of derivative at hyperplane boundary
00042         double alpha_;
00043         //: 1/(1+var_)^2 - with alpha, ensures continuity of function at hyperplane boundary
00044         double beta_;
00045       public:
00046         //: construct passing in reference to data matrix
00047         gmrho_sum(const vnl_matrix<double>& x,
00048                   const vnl_vector<double>& y,double sigma=1);
00049 
00050         //: reset the scaling factor
00051         void set_sigma(double sigma);
00052 
00053         //:  The main function.  Given the vector of weights parameters vector , compute the value of f(x).
00054         virtual double f(vnl_vector<double> const& w);
00055 
00056         //:  Calculate the gradient of f at parameter vector x.
00057         virtual void gradf(vnl_vector<double> const& x, vnl_vector<double>& gradient);
00058     };
00059 
00060     //: functor to accumulate gradient contributions for given training example
00061     class gm_grad_accum
00062     {
00063         const double* px_;
00064         const double wt_;
00065       public:
00066         gm_grad_accum(const double* px,double wt) : px_(px),wt_(wt) {}
00067         void operator()(double& grad)
00068         {
00069             grad += (*px_++) * wt_;
00070         }
00071     };
00072 
00073     //: Given the class category variable, return the associated regression value (e.g. 1 for class 1, -1 for class 0)
00074     class category_value
00075     {
00076 //        const double y0;
00077 //        const double y1;
00078       public:
00079         category_value(vcl_size_t /*num_category1*/, vcl_size_t /*num_total*/)
00080 //          : y0(-1.0*double(num_total-num_category1)/double(num_total)),
00081 //            y1(double(num_category1)/double(num_total))
00082         {}
00083 
00084         double operator()(const unsigned& classNum)
00085         {
00086             //return classNum ? y1 : y0;
00087             return classNum ? 1.0 : -1.0;
00088         }
00089     };
00090 };
00091 
00092 //-----------------------------------------------------------------------------------------------
00093 //------------------------ The builder member functions ------------------------------------------
00094 //------------------------------------------------------------------------------------------------
00095 //: Build a linear classifier, with the given data.
00096 // Return the mean error over the training set.
00097 // n_classes must be 1.
00098 double clsfy_binary_hyperplane_gmrho_builder::build(clsfy_classifier_base& classifier,
00099                                                     mbl_data_wrapper<vnl_vector<double> >& inputs,
00100                                                     unsigned n_classes,
00101                                                     const vcl_vector<unsigned>& outputs) const
00102 {
00103     assert (n_classes == 1);
00104     return clsfy_binary_hyperplane_gmrho_builder::build(classifier, inputs, outputs);
00105 }
00106 
00107 //: Build a linear hyperplane classifier with the given data.
00108 // Reduce the influence of well classified points far into their correct region by
00109 // applying a Geman-McClure robust error function, rather than a least squares fit
00110 double clsfy_binary_hyperplane_gmrho_builder::build(clsfy_classifier_base& classifier,
00111                                                     mbl_data_wrapper<vnl_vector<double> >& inputs,
00112                                                     const vcl_vector<unsigned>& outputs) const
00113 {
00114     using clsfy_binary_hyperplane_gmrho_builder_helpers::category_value;
00115 
00116     //First let the base class get us a starting solution
00117     clsfy_binary_hyperplane_ls_builder::build( classifier,inputs,outputs);
00118     //Extract the data into a matrix
00119     num_examples_ = inputs.size();
00120     if (num_examples_ == 0)
00121     {
00122         vcl_cerr<<"WARNING - clsfy_binary_hyperplane_gmrho_builder::build called with no data\n";
00123         return 0.0;
00124     }
00125 
00126     //Now copy from the urggghh data wrapper into a sensible data structure (matrix!)
00127     inputs.reset();
00128     num_vars_ = inputs.current().size();
00129     vnl_matrix<double> data(num_examples_,num_vars_,0.0);
00130     unsigned i=0;
00131     do
00132     {
00133         double* row=data[i++];
00134         vcl_copy(inputs.current().begin(),inputs.current().end(),row);
00135     } while (inputs.next());
00136 
00137     //Set up category regression values determined by output class
00138     vnl_vector<double> y(num_examples_,0.0);
00139     vcl_transform(outputs.begin(),outputs.end(),
00140                   y.begin(),
00141                   category_value(vcl_count(outputs.begin(),outputs.end(),1u),outputs.size()));
00142     weights_.set_size(num_vars_+1);
00143 
00144     //Initialise the weights using the standard least squares fit of my base class
00145     clsfy_binary_hyperplane& hyperplane = dynamic_cast<clsfy_binary_hyperplane &>(classifier);
00146 
00147     weights_.update(hyperplane.weights(),0);
00148     weights_[num_vars_] = hyperplane.bias();
00149 
00150     //Estimate the scaling factor used in the Geman-McClure function
00151     double sigma_scale_target = sigma_preset_;
00152     if (auto_estimate_sigma_)
00153         sigma_scale_target=estimate_sigma(data,y);
00154 
00155     //To avoid local minima perform deterministic annealing starting from a large initial sigma
00156     //Set initial kappa so that everything is an inlier
00157     double kappa = 5.0;
00158     const double alpha_anneal=0.75;
00159     //Num of iterations to reduce back to 10% on top of required sigma
00160     int N = 1+int(vcl_log(1.1/kappa)/vcl_log(alpha_anneal));
00161     if (N<1) N=1;
00162     double sigma_scale = kappa * sigma_scale_target;
00163 
00164     epsilon_ = 1.0E-4; //slacken off convergence tolerance during annealing
00165     for (int ianneal=0;ianneal<N;++ianneal)
00166     {
00167         //Then do it at this sigma
00168         determine_weights(data,y,sigma_scale);
00169         //and then reduce sigma
00170         sigma_scale *= alpha_anneal;
00171     }
00172 
00173     epsilon_ = 1.0E-8; //re-impose a more precise convergence criterion
00174     //Then re-estimate sigma scale and do a final pair of iterations
00175     //as sigma depends on the mis-classification overlap depth
00176 
00177 
00178     for (unsigned iter=0; iter<(auto_estimate_sigma_ ? 2u : 1u); ++iter)
00179     {
00180         if (auto_estimate_sigma_)
00181             sigma_scale_target=estimate_sigma(data,y);
00182         else
00183             sigma_scale_target = sigma_preset_;
00184         //Finally do it at exactly the target sigma
00185         determine_weights(data,y,sigma_scale_target);
00186     }
00187     //And finally copy the parameters into the hyperplane
00188     vnl_vector_ref<double > weights(num_vars_,weights_.data_block());
00189     hyperplane.set(weights, weights_[num_vars_]);
00190 
00191     return clsfy_test_error(classifier, inputs, outputs);
00192 }
00193 
00194 void clsfy_binary_hyperplane_gmrho_builder::determine_weights(const vnl_matrix<double>& data,
00195                                                               const vnl_vector<double >& y,
00196                                                               double sigma) const
00197 {
00198     //Optimise the weights to fit the data to y
00199 
00200     clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum costFn(data,y,sigma);
00201 
00202     //minimise using the quasi-Newton lbfgs method
00203     vnl_lbfgs cgMinimiser(costFn);
00204 
00205     cgMinimiser.set_f_tolerance(epsilon_);
00206     cgMinimiser.set_x_tolerance(epsilon_);
00207 
00208     cgMinimiser.minimize(weights_);
00209 }
00210 
00211 double clsfy_binary_hyperplane_gmrho_builder::estimate_sigma(const vnl_matrix<double>& data,
00212                                                              const vnl_vector<double >& y) const
00213 {
00214     //Sigma is set to root(3) * (1+d), where d is the median distance past zero
00215     //of the misclassified values
00216     //The root(3) is because GM function reduces influence after sigma/sqrt(3)
00217 
00218     vcl_vector<double > falsePosScores;
00219     vcl_vector<double > falseNegScores;
00220 
00221     double b=weights_[num_vars_]; //constant stored as final variable
00222     for (unsigned i=0; i<num_examples_;++i) //Loop over examples (matrix rows)
00223     {
00224         const double* px=data[i];
00225         double yval = y[i];
00226         double ypred = vcl_inner_product(px,px+num_vars_,weights_.begin(),0.0) - b ;
00227         if (yval>0.0)
00228         {
00229             if (ypred<0.0) // mis-classified false negative
00230             {
00231                 falseNegScores.push_back(vcl_fabs(ypred));
00232             }
00233         }
00234         else
00235         {
00236             if (ypred>0.0)//mis-classified false negative
00237             {
00238                 falsePosScores.push_back(vcl_fabs(ypred));
00239             }
00240         }
00241     }
00242     double sigma=1.0;
00243     double delta0=0.0;
00244     if (!falsePosScores.empty())
00245     {
00246         vcl_vector<double >::iterator medianIter=falsePosScores.begin() + falsePosScores.size()/2;
00247         vcl_nth_element(falsePosScores.begin(),medianIter,falsePosScores.end());
00248         delta0 = (*medianIter);
00249     }
00250     double delta1=0.0;
00251     if (!falseNegScores.empty())
00252     {
00253         vcl_vector<double >::iterator medianIter=falseNegScores.begin() + falseNegScores.size()/2;
00254         vcl_nth_element(falseNegScores.begin(),medianIter,falseNegScores.end());
00255         delta1 = (*medianIter);
00256     }
00257     sigma += vcl_max(delta0,delta1);
00258 
00259     sigma *= vcl_sqrt(3.0);
00260     return sigma;
00261 }
00262 
00263 //=======================================================================
00264 
00265 void clsfy_binary_hyperplane_gmrho_builder::b_write(vsl_b_ostream &bfs) const
00266 {
00267     const int version_no=1;
00268     vsl_b_write(bfs, version_no);
00269     clsfy_binary_hyperplane_ls_builder::b_write(bfs);
00270 }
00271 
00272 //=======================================================================
00273 
00274 void clsfy_binary_hyperplane_gmrho_builder::b_read(vsl_b_istream &bfs)
00275 {
00276     if (!bfs) return;
00277 
00278     short version;
00279     vsl_b_read(bfs,version);
00280     switch (version)
00281     {
00282         case (1):
00283             clsfy_binary_hyperplane_ls_builder::b_read(bfs);
00284             break;
00285         default:
00286             vcl_cerr << "I/O ERROR: clsfy_binary_hyperplane_gmrho_builder::b_read(vsl_b_istream&)\n"
00287                      << "           Unknown version number "<< version << '\n';
00288             bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00289     }
00290 }
00291 
00292 //=======================================================================
00293 
00294 vcl_string clsfy_binary_hyperplane_gmrho_builder::is_a() const
00295 {
00296     return vcl_string("clsfy_binary_hyperplane_gmrho_builder");
00297 }
00298 
00299 //=======================================================================
00300 
00301 bool clsfy_binary_hyperplane_gmrho_builder::is_class(vcl_string const& s) const
00302 {
00303     return s == clsfy_binary_hyperplane_gmrho_builder::is_a() || clsfy_binary_hyperplane_ls_builder::is_class(s);
00304 }
00305 
00306 //=======================================================================
00307 
00308 short clsfy_binary_hyperplane_gmrho_builder::version_no() const
00309 {
00310     return 1;
00311 }
00312 
00313 //=======================================================================
00314 
00315 void clsfy_binary_hyperplane_gmrho_builder::print_summary(vcl_ostream& os) const
00316 {
00317     os << is_a();
00318 }
00319 
00320 //=======================================================================
00321 clsfy_builder_base* clsfy_binary_hyperplane_gmrho_builder::clone() const
00322 {
00323     return new clsfy_binary_hyperplane_gmrho_builder(*this);
00324 }
00325 
00326 //---------------------------------------------------------------------------------------------
//: The error function class
//  For a correctly classified point the Geman-McClure robust function of the residual is used.
//  Otherwise a scaled and offset squared error is used, with the coefficient and offset chosen
//  so that both the function and its derivative are continuous at the join.
00331 //---------------------------------------------------------------------------------------------
00332 clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::gmrho_sum(const vnl_matrix<double>& x,
00333                                                                     const vnl_vector<double>& y,
00334                                                                     double sigma):
00335         vnl_cost_function(x.cols()+1),
00336         x_(x),y_(y),sigma_(1.0),var_(1.0),num_examples_(x.rows()),num_vars_(x.cols())
00337 {
00338     set_sigma(sigma);
00339 }
00340 
00341 void clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::set_sigma(double sigma)
00342 {
00343     sigma_ = sigma;
00344     var_ = sigma*sigma;
00345     double s=1.0+var_;
00346     s = s*s;
00347     alpha_ = var_/s;
00348     beta_ = 1.0/s;
00349 }
00350 
00351 
00352 //: Return the error sum function
00353 double clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::f(vnl_vector<double> const& w)
00354 {
00355     //Sum the error contributions from each example
00356     double sum=0.0;
00357     double b=w[num_vars_]; //constant stored as final variable
00358     for (unsigned i=0; i<num_examples_;++i) //Loop over examples (matrix rows)
00359     {
00360         const double* px=x_[i];
00361         double pred = vcl_inner_product(px,px+num_vars_,w.begin(),0.0) - b;
00362         double e =  y_[i] - pred;
00363         double e2 = e*e;
00364         if ( ((y_[i] > 0.0) && (e <= 1.0)) ||
00365              ((y_[i] < 0.0) && (e >= -1.0)) )
00366         {
00367             //In the correctly classified region
00368             //So use Geman-McClure function
00369             sum += e2/(e2+var_);
00370         }
00371         else
00372         {
00373             //Misclassified, so keep as quadratic (influence increases with error)
00374             //NB alpha and beta are chosen for continuity of function and gradient at boundary
00375             sum += alpha_*e2 + beta_;
00376         }
00377     }
00378     return sum;
00379 }
00380 
00381 //: Calculate gradient of the error sum function
00382 void clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::gradf(vnl_vector<double> const& w,
00383                                                                      vnl_vector<double>& gradient)
00384 {
00385     using clsfy_binary_hyperplane_gmrho_builder_helpers::gm_grad_accum;
00386     double b=w[num_vars_]; //constant stored as final variable
00387     gradient.fill(0.0);
00388 
00389     for (unsigned i=0; i<num_examples_;++i) //Loop over examples (matrix rows)
00390     {
00391         const double* px=x_[i];
00392         double pred = vcl_inner_product(px,px+num_vars_,w.begin(),0.0) - b;
00393 
00394         double e =  y_[i] - pred;
00395         double e2 = e*e;
00396         double wt=1.0;
00397         if ( ((y_[i] > 0.0) && (e <= 1.0)) ||
00398              ((y_[i] < 0.0) && (e >= -1.0)) )
00399         {
00400             wt = e2 + var_;
00401         }
00402         else
00403         {
00404             //Freeze weight decay once in misclassification region
00405             wt = 1.0 + var_;
00406         }
00407 
00408         double wtInv = -e/(wt*wt);
00409         vcl_for_each(gradient.begin(),gradient.begin()+num_vars_,
00410                      gm_grad_accum(px,wtInv));
00411 
00412         gradient[num_vars_] += (-wtInv); //dg/db, last term is for constant
00413     }
00414     //And multiply everything by 2sigma^2
00415     vcl_transform(gradient.begin(),gradient.end(),gradient.begin(),
00416                   vcl_bind2nd(vcl_multiplies<double>(),2.0*var_));
00417 }