00001
00002 #include "clsfy_binary_hyperplane_gmrho_builder.h"
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <vcl_string.h>
00012 #include <vcl_iostream.h>
00013 #include <vcl_vector.h>
00014 #include <vcl_cassert.h>
00015 #include <vcl_cmath.h>
00016 #include <vcl_algorithm.h>
00017 #include <vcl_numeric.h>
00018 #include <vcl_cstddef.h>
00019 #include <vnl/vnl_vector_ref.h>
00020 #include <vnl/algo/vnl_lbfgs.h>
00021
00022
00023
00024 namespace clsfy_binary_hyperplane_gmrho_builder_helpers
00025 {
00026
00027 class gmrho_sum : public vnl_cost_function
00028 {
00029
00030 const vnl_matrix<double>& x_;
00031
00032 const vnl_vector<double>& y_;
00033
00034 double sigma_;
00035
00036 double var_;
00037
00038 unsigned num_examples_;
00039
00040 unsigned num_vars_;
00041
00042 double alpha_;
00043
00044 double beta_;
00045 public:
00046
00047 gmrho_sum(const vnl_matrix<double>& x,
00048 const vnl_vector<double>& y,double sigma=1);
00049
00050
00051 void set_sigma(double sigma);
00052
00053
00054 virtual double f(vnl_vector<double> const& w);
00055
00056
00057 virtual void gradf(vnl_vector<double> const& x, vnl_vector<double>& gradient);
00058 };
00059
00060
00061 class gm_grad_accum
00062 {
00063 const double* px_;
00064 const double wt_;
00065 public:
00066 gm_grad_accum(const double* px,double wt) : px_(px),wt_(wt) {}
00067 void operator()(double& grad)
00068 {
00069 grad += (*px_++) * wt_;
00070 }
00071 };
00072
00073
00074 class category_value
00075 {
00076
00077
00078 public:
00079 category_value(vcl_size_t , vcl_size_t )
00080
00081
00082 {}
00083
00084 double operator()(const unsigned& classNum)
00085 {
00086
00087 return classNum ? 1.0 : -1.0;
00088 }
00089 };
00090 };
00091
00092
00093
00094
00095
00096
00097
00098 double clsfy_binary_hyperplane_gmrho_builder::build(clsfy_classifier_base& classifier,
00099 mbl_data_wrapper<vnl_vector<double> >& inputs,
00100 unsigned n_classes,
00101 const vcl_vector<unsigned>& outputs) const
00102 {
00103 assert (n_classes == 1);
00104 return clsfy_binary_hyperplane_gmrho_builder::build(classifier, inputs, outputs);
00105 }
00106
00107
00108
00109
00110 double clsfy_binary_hyperplane_gmrho_builder::build(clsfy_classifier_base& classifier,
00111 mbl_data_wrapper<vnl_vector<double> >& inputs,
00112 const vcl_vector<unsigned>& outputs) const
00113 {
00114 using clsfy_binary_hyperplane_gmrho_builder_helpers::category_value;
00115
00116
00117 clsfy_binary_hyperplane_ls_builder::build( classifier,inputs,outputs);
00118
00119 num_examples_ = inputs.size();
00120 if (num_examples_ == 0)
00121 {
00122 vcl_cerr<<"WARNING - clsfy_binary_hyperplane_gmrho_builder::build called with no data\n";
00123 return 0.0;
00124 }
00125
00126
00127 inputs.reset();
00128 num_vars_ = inputs.current().size();
00129 vnl_matrix<double> data(num_examples_,num_vars_,0.0);
00130 unsigned i=0;
00131 do
00132 {
00133 double* row=data[i++];
00134 vcl_copy(inputs.current().begin(),inputs.current().end(),row);
00135 } while (inputs.next());
00136
00137
00138 vnl_vector<double> y(num_examples_,0.0);
00139 vcl_transform(outputs.begin(),outputs.end(),
00140 y.begin(),
00141 category_value(vcl_count(outputs.begin(),outputs.end(),1u),outputs.size()));
00142 weights_.set_size(num_vars_+1);
00143
00144
00145 clsfy_binary_hyperplane& hyperplane = dynamic_cast<clsfy_binary_hyperplane &>(classifier);
00146
00147 weights_.update(hyperplane.weights(),0);
00148 weights_[num_vars_] = hyperplane.bias();
00149
00150
00151 double sigma_scale_target = sigma_preset_;
00152 if (auto_estimate_sigma_)
00153 sigma_scale_target=estimate_sigma(data,y);
00154
00155
00156
00157 double kappa = 5.0;
00158 const double alpha_anneal=0.75;
00159
00160 int N = 1+int(vcl_log(1.1/kappa)/vcl_log(alpha_anneal));
00161 if (N<1) N=1;
00162 double sigma_scale = kappa * sigma_scale_target;
00163
00164 epsilon_ = 1.0E-4;
00165 for (int ianneal=0;ianneal<N;++ianneal)
00166 {
00167
00168 determine_weights(data,y,sigma_scale);
00169
00170 sigma_scale *= alpha_anneal;
00171 }
00172
00173 epsilon_ = 1.0E-8;
00174
00175
00176
00177
00178 for (unsigned iter=0; iter<(auto_estimate_sigma_ ? 2u : 1u); ++iter)
00179 {
00180 if (auto_estimate_sigma_)
00181 sigma_scale_target=estimate_sigma(data,y);
00182 else
00183 sigma_scale_target = sigma_preset_;
00184
00185 determine_weights(data,y,sigma_scale_target);
00186 }
00187
00188 vnl_vector_ref<double > weights(num_vars_,weights_.data_block());
00189 hyperplane.set(weights, weights_[num_vars_]);
00190
00191 return clsfy_test_error(classifier, inputs, outputs);
00192 }
00193
00194 void clsfy_binary_hyperplane_gmrho_builder::determine_weights(const vnl_matrix<double>& data,
00195 const vnl_vector<double >& y,
00196 double sigma) const
00197 {
00198
00199
00200 clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum costFn(data,y,sigma);
00201
00202
00203 vnl_lbfgs cgMinimiser(costFn);
00204
00205 cgMinimiser.set_f_tolerance(epsilon_);
00206 cgMinimiser.set_x_tolerance(epsilon_);
00207
00208 cgMinimiser.minimize(weights_);
00209 }
00210
00211 double clsfy_binary_hyperplane_gmrho_builder::estimate_sigma(const vnl_matrix<double>& data,
00212 const vnl_vector<double >& y) const
00213 {
00214
00215
00216
00217
00218 vcl_vector<double > falsePosScores;
00219 vcl_vector<double > falseNegScores;
00220
00221 double b=weights_[num_vars_];
00222 for (unsigned i=0; i<num_examples_;++i)
00223 {
00224 const double* px=data[i];
00225 double yval = y[i];
00226 double ypred = vcl_inner_product(px,px+num_vars_,weights_.begin(),0.0) - b ;
00227 if (yval>0.0)
00228 {
00229 if (ypred<0.0)
00230 {
00231 falseNegScores.push_back(vcl_fabs(ypred));
00232 }
00233 }
00234 else
00235 {
00236 if (ypred>0.0)
00237 {
00238 falsePosScores.push_back(vcl_fabs(ypred));
00239 }
00240 }
00241 }
00242 double sigma=1.0;
00243 double delta0=0.0;
00244 if (!falsePosScores.empty())
00245 {
00246 vcl_vector<double >::iterator medianIter=falsePosScores.begin() + falsePosScores.size()/2;
00247 vcl_nth_element(falsePosScores.begin(),medianIter,falsePosScores.end());
00248 delta0 = (*medianIter);
00249 }
00250 double delta1=0.0;
00251 if (!falseNegScores.empty())
00252 {
00253 vcl_vector<double >::iterator medianIter=falseNegScores.begin() + falseNegScores.size()/2;
00254 vcl_nth_element(falseNegScores.begin(),medianIter,falseNegScores.end());
00255 delta1 = (*medianIter);
00256 }
00257 sigma += vcl_max(delta0,delta1);
00258
00259 sigma *= vcl_sqrt(3.0);
00260 return sigma;
00261 }
00262
00263
00264
00265 void clsfy_binary_hyperplane_gmrho_builder::b_write(vsl_b_ostream &bfs) const
00266 {
00267 const int version_no=1;
00268 vsl_b_write(bfs, version_no);
00269 clsfy_binary_hyperplane_ls_builder::b_write(bfs);
00270 }
00271
00272
00273
00274 void clsfy_binary_hyperplane_gmrho_builder::b_read(vsl_b_istream &bfs)
00275 {
00276 if (!bfs) return;
00277
00278 short version;
00279 vsl_b_read(bfs,version);
00280 switch (version)
00281 {
00282 case (1):
00283 clsfy_binary_hyperplane_ls_builder::b_read(bfs);
00284 break;
00285 default:
00286 vcl_cerr << "I/O ERROR: clsfy_binary_hyperplane_gmrho_builder::b_read(vsl_b_istream&)\n"
00287 << " Unknown version number "<< version << '\n';
00288 bfs.is().clear(vcl_ios::badbit);
00289 }
00290 }
00291
00292
00293
00294 vcl_string clsfy_binary_hyperplane_gmrho_builder::is_a() const
00295 {
00296 return vcl_string("clsfy_binary_hyperplane_gmrho_builder");
00297 }
00298
00299
00300
00301 bool clsfy_binary_hyperplane_gmrho_builder::is_class(vcl_string const& s) const
00302 {
00303 return s == clsfy_binary_hyperplane_gmrho_builder::is_a() || clsfy_binary_hyperplane_ls_builder::is_class(s);
00304 }
00305
00306
00307
00308 short clsfy_binary_hyperplane_gmrho_builder::version_no() const
00309 {
00310 return 1;
00311 }
00312
00313
00314
00315 void clsfy_binary_hyperplane_gmrho_builder::print_summary(vcl_ostream& os) const
00316 {
00317 os << is_a();
00318 }
00319
00320
00321 clsfy_builder_base* clsfy_binary_hyperplane_gmrho_builder::clone() const
00322 {
00323 return new clsfy_binary_hyperplane_gmrho_builder(*this);
00324 }
00325
00326
00327
00328
00329
00330
00331
00332 clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::gmrho_sum(const vnl_matrix<double>& x,
00333 const vnl_vector<double>& y,
00334 double sigma):
00335 vnl_cost_function(x.cols()+1),
00336 x_(x),y_(y),sigma_(1.0),var_(1.0),num_examples_(x.rows()),num_vars_(x.cols())
00337 {
00338 set_sigma(sigma);
00339 }
00340
00341 void clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::set_sigma(double sigma)
00342 {
00343 sigma_ = sigma;
00344 var_ = sigma*sigma;
00345 double s=1.0+var_;
00346 s = s*s;
00347 alpha_ = var_/s;
00348 beta_ = 1.0/s;
00349 }
00350
00351
00352
00353 double clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::f(vnl_vector<double> const& w)
00354 {
00355
00356 double sum=0.0;
00357 double b=w[num_vars_];
00358 for (unsigned i=0; i<num_examples_;++i)
00359 {
00360 const double* px=x_[i];
00361 double pred = vcl_inner_product(px,px+num_vars_,w.begin(),0.0) - b;
00362 double e = y_[i] - pred;
00363 double e2 = e*e;
00364 if ( ((y_[i] > 0.0) && (e <= 1.0)) ||
00365 ((y_[i] < 0.0) && (e >= -1.0)) )
00366 {
00367
00368
00369 sum += e2/(e2+var_);
00370 }
00371 else
00372 {
00373
00374
00375 sum += alpha_*e2 + beta_;
00376 }
00377 }
00378 return sum;
00379 }
00380
00381
00382 void clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::gradf(vnl_vector<double> const& w,
00383 vnl_vector<double>& gradient)
00384 {
00385 using clsfy_binary_hyperplane_gmrho_builder_helpers::gm_grad_accum;
00386 double b=w[num_vars_];
00387 gradient.fill(0.0);
00388
00389 for (unsigned i=0; i<num_examples_;++i)
00390 {
00391 const double* px=x_[i];
00392 double pred = vcl_inner_product(px,px+num_vars_,w.begin(),0.0) - b;
00393
00394 double e = y_[i] - pred;
00395 double e2 = e*e;
00396 double wt=1.0;
00397 if ( ((y_[i] > 0.0) && (e <= 1.0)) ||
00398 ((y_[i] < 0.0) && (e >= -1.0)) )
00399 {
00400 wt = e2 + var_;
00401 }
00402 else
00403 {
00404
00405 wt = 1.0 + var_;
00406 }
00407
00408 double wtInv = -e/(wt*wt);
00409 vcl_for_each(gradient.begin(),gradient.begin()+num_vars_,
00410 gm_grad_accum(px,wtInv));
00411
00412 gradient[num_vars_] += (-wtInv);
00413 }
00414
00415 vcl_transform(gradient.begin(),gradient.end(),gradient.begin(),
00416 vcl_bind2nd(vcl_multiplies<double>(),2.0*var_));
00417 }