Go to the documentation of this file.00001
00002
00003 #include "clsfy_binary_hyperplane_ls_builder.h"
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include <vcl_string.h>
00013 #include <vcl_iostream.h>
00014 #include <vcl_vector.h>
00015 #include <vcl_cassert.h>
00016 #include <vcl_algorithm.h>
00017 #include <vnl/algo/vnl_svd.h>
00018 #include <vnl/vnl_math.h>
00019
00020
00021
00022 vcl_string clsfy_binary_hyperplane_ls_builder::is_a() const
00023 {
00024 return vcl_string("clsfy_binary_hyperplane_ls_builder");
00025 }
00026
00027
00028
00029 bool clsfy_binary_hyperplane_ls_builder::is_class(vcl_string const& s) const
00030 {
00031 return s == clsfy_binary_hyperplane_ls_builder::is_a() || clsfy_builder_base::is_class(s);
00032 }
00033
00034
00035
00036 void clsfy_binary_hyperplane_ls_builder::print_summary(vcl_ostream& os) const
00037 {
00038 os << is_a();
00039 }
00040
00041
00042
00043
00044 double clsfy_binary_hyperplane_ls_builder::build(
00045 clsfy_classifier_base &classifier, mbl_data_wrapper<vnl_vector<double> > &inputs,
00046 const vcl_vector<unsigned> &outputs) const
00047 {
00048 assert(outputs.size() == inputs.size());
00049 assert(* vcl_max_element(outputs.begin(), outputs.end()) <= 1);
00050 assert(classifier.is_class("clsfy_binary_hyperplane"));
00051
00052 clsfy_binary_hyperplane &hyperplane = (clsfy_binary_hyperplane &) classifier;
00053
00054 inputs.reset();
00055 const unsigned k = inputs.current().size();
00056 vnl_matrix<double> XtX(k+1, k+1, 0.0);
00057 vnl_vector<double> XtY(k+1, 0.0);
00058
00059 #if 0 // The calculation is as follows
00060 do
00061 {
00062
00063 const vnl_vector<double> &x=inputs.current();
00064 double y = outputs[inputs.index()] ? 1.0 : -1.0;
00065 vnl_vector<double> xp(k+1);
00066 xp.update(x, 0);
00067 xp(k) = -1.0;
00068 XtX += outer_product(xp, xp);
00069 double y = outputs[inputs.index()] ? 1.0 : -1.0;
00070 XtY += y * xp;
00071 } while (inputs.next());
00072 #else// However the following version is faster
00073 do
00074 {
00075
00076 const vnl_vector<double> &x=inputs.current();
00077 double y = outputs[inputs.index()] ? 1.0 : -1.0;
00078 for (unsigned i=0; i<k; ++i)
00079 {
00080 for (unsigned j=0; j<i; ++j)
00081 XtX(i,j) += x(i) * x(j);
00082 XtX(i,i) += vnl_math_sqr(x(i));
00083 XtX(i,k) -= x(i);
00084 XtY(i) += y * x(i);
00085 }
00086 XtY(k) += y * -1.0;
00087
00088 } while (inputs.next());
00089 for (unsigned i=0; i<k; ++i)
00090 {
00091 for (unsigned j=0; j<i; ++j)
00092 XtX(j,i) += XtX(i,j);
00093 XtX(k,i) = XtX(i,k);
00094 }
00095 XtX(k, k) = (double) inputs.size();
00096 #endif
00097
00098
00099
00100
00101
00102
00103 vnl_svd<double> svd(XtX, 1.0e-12);
00104 vnl_vector<double> w = svd.solve(XtY);
00105 #if 0
00106 vcl_cerr << "XtX: " << XtX << vcl_endl
00107 << "XtY: " << XtY << vcl_endl
00108 << "w: " << w << vcl_endl;
00109 #endif
00110 vnl_vector<double> weights(&w(0), k);
00111 hyperplane.set(weights, w(k));
00112
00113 return clsfy_test_error(classifier, inputs, outputs);
00114 }
00115
00116
00117
00118
00119
00120
00121
00122
00123 double clsfy_binary_hyperplane_ls_builder::build(
00124 clsfy_classifier_base &classifier, mbl_data_wrapper<vnl_vector<double> > &inputs,
00125 unsigned n_classes, const vcl_vector<unsigned> &outputs) const
00126 {
00127 assert (n_classes == 1);
00128 return build(classifier, inputs, outputs);
00129 }
00130
00131
00132
00133 void clsfy_binary_hyperplane_ls_builder::b_write(vsl_b_ostream &bfs) const
00134 {
00135 const short version_no=1;
00136 vsl_b_write(bfs, version_no);
00137 }
00138
00139
00140
00141 void clsfy_binary_hyperplane_ls_builder::b_read(vsl_b_istream &bfs)
00142 {
00143 if (!bfs) return;
00144
00145 short version;
00146 vsl_b_read(bfs,version);
00147 switch (version)
00148 {
00149 case (1):
00150 break;
00151 default:
00152 vcl_cerr << "I/O ERROR: clsfy_binary_hyperplane_ls_builder::b_read(vsl_b_istream&)\n"
00153 << " Unknown version number "<< version << '\n';
00154 bfs.is().clear(vcl_ios::badbit);
00155 }
00156 }