Go to the documentation of this file.00001
00002 #include "clsfy_binary_hyperplane_logit_builder.h"
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <vcl_string.h>
00012 #include <vcl_iostream.h>
00013 #include <vcl_vector.h>
00014 #include <vcl_cassert.h>
00015 #include <vcl_cmath.h>
00016 #include <vcl_algorithm.h>
00017 #include <vcl_numeric.h>
00018 #include <vnl/vnl_vector_ref.h>
00019 #include <vnl/algo/vnl_lbfgs.h>
00020 #include <clsfy/clsfy_logit_loss_function.h>
00021
00022 clsfy_binary_hyperplane_logit_builder::clsfy_binary_hyperplane_logit_builder():
00023 clsfy_binary_hyperplane_ls_builder(),
00024 alpha_(1e-6),min_p_(0.001) {}
00025
00026
00027 void clsfy_binary_hyperplane_logit_builder::set_alpha(double a)
00028 {
00029 alpha_=a;
00030 }
00031
00032
00033 void clsfy_binary_hyperplane_logit_builder::set_min_p(double p)
00034 {
00035 min_p_=p;
00036 }
00037
00038
00039
00040
00041
00042
00043
00044
00045 double clsfy_binary_hyperplane_logit_builder::build(clsfy_classifier_base& classifier,
00046 mbl_data_wrapper<vnl_vector<double> >& inputs,
00047 unsigned n_classes,
00048 const vcl_vector<unsigned>& outputs) const
00049 {
00050 assert (n_classes == 1);
00051 return clsfy_binary_hyperplane_logit_builder::build(classifier, inputs, outputs);
00052 }
00053
00054
00055 double clsfy_binary_hyperplane_logit_builder::build(clsfy_classifier_base& classifier,
00056 mbl_data_wrapper<vnl_vector<double> >& inputs,
00057 const vcl_vector<unsigned>& outputs) const
00058 {
00059
00060 clsfy_binary_hyperplane_ls_builder::build( classifier,inputs,outputs);
00061
00062 vcl_cout<<"Initial error:"<< clsfy_test_error(classifier, inputs, outputs) <<vcl_endl;
00063
00064 unsigned n_egs = inputs.size();
00065 if (n_egs == 0)
00066 {
00067 vcl_cerr<<"WARNING - clsfy_binary_hyperplane_logit_builder::build called with no data\n";
00068 return 0.0;
00069 }
00070
00071 assert(classifier.is_a()=="clsfy_binary_hyperplane");
00072 clsfy_binary_hyperplane& plane = static_cast<clsfy_binary_hyperplane&>(classifier);
00073
00074
00075 unsigned n_dim = plane.n_dims();
00076 vnl_vector<double> w(n_dim+1);
00077 w[0]=-1*plane.bias();
00078 for (unsigned i=0;i<n_dim;++i) w[1+i]=plane.weights()[i];
00079
00080
00081 vnl_vector<double> c(n_egs);
00082 for (unsigned i=0;i<n_egs;++i) c[i]=2.0*(outputs[i]-0.5);
00083 clsfy_quad_regulariser quad_reg(alpha_);
00084 clsfy_logit_loss_function cost_fn(inputs,c,min_p_,&quad_reg);
00085
00086
00087 vnl_lbfgs optimizer(cost_fn);
00088 optimizer.set_verbose(true);
00089 optimizer.set_f_tolerance(1e-7);
00090 optimizer.set_x_tolerance(1e-5);
00091 if (!optimizer.minimize(w))
00092 {
00093 vcl_cerr<<"vnl_lbfgs optimisation failed!"<<vcl_endl;
00094 vcl_cerr<<"Failure code: "<<optimizer.get_failure_code()<<vcl_endl;
00095 }
00096
00097 vnl_vector_ref<double> new_wts(n_dim,&w[1]);
00098
00099 plane.set(new_wts,-w[0]);
00100
00101 return clsfy_test_error(classifier, inputs, outputs);
00102 }
00103
00104
00105
00106 void clsfy_binary_hyperplane_logit_builder::b_write(vsl_b_ostream &bfs) const
00107 {
00108 const int version_no=1;
00109 vsl_b_write(bfs, version_no);
00110 clsfy_binary_hyperplane_ls_builder::b_write(bfs);
00111 vsl_b_write(bfs,min_p_);
00112 vsl_b_write(bfs,alpha_);
00113 }
00114
00115
00116
00117 void clsfy_binary_hyperplane_logit_builder::b_read(vsl_b_istream &bfs)
00118 {
00119 if (!bfs) return;
00120
00121 short version;
00122 vsl_b_read(bfs,version);
00123 switch (version)
00124 {
00125 case (1):
00126 clsfy_binary_hyperplane_ls_builder::b_read(bfs);
00127 vsl_b_read(bfs,min_p_);
00128 vsl_b_read(bfs,alpha_);
00129 break;
00130 default:
00131 vcl_cerr << "I/O ERROR: clsfy_binary_hyperplane_logit_builder::b_read(vsl_b_istream&)\n"
00132 << " Unknown version number "<< version << '\n';
00133 bfs.is().clear(vcl_ios::badbit);
00134 }
00135 }
00136
00137
00138
00139 vcl_string clsfy_binary_hyperplane_logit_builder::is_a() const
00140 {
00141 return vcl_string("clsfy_binary_hyperplane_logit_builder");
00142 }
00143
00144
00145
00146 bool clsfy_binary_hyperplane_logit_builder::is_class(vcl_string const& s) const
00147 {
00148 return s == clsfy_binary_hyperplane_logit_builder::is_a() || clsfy_binary_hyperplane_ls_builder::is_class(s);
00149 }
00150
00151
00152
00153 short clsfy_binary_hyperplane_logit_builder::version_no() const
00154 {
00155 return 1;
00156 }
00157
00158
00159
00160 void clsfy_binary_hyperplane_logit_builder::print_summary(vcl_ostream& os) const
00161 {
00162 os << is_a();
00163 }
00164
00165
00166 clsfy_builder_base* clsfy_binary_hyperplane_logit_builder::clone() const
00167 {
00168 return new clsfy_binary_hyperplane_logit_builder(*this);
00169 }
00170