contrib/mul/clsfy/clsfy_binary_hyperplane_ls_builder.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_hyperplane_ls_builder.cxx
00002 // Copyright: (C) 2001 British Telecommunications PLC
00003 #include "clsfy_binary_hyperplane_ls_builder.h"
00004 //:
00005 // \file
00006 // \brief Implement a two-class output linear classifier builder
00007 // \author Ian Scott
00008 // \date 4 June 2001
00009 
00010 //=======================================================================
00011 
00012 #include <vcl_string.h>
00013 #include <vcl_iostream.h>
00014 #include <vcl_vector.h>
00015 #include <vcl_cassert.h>
00016 #include <vcl_algorithm.h>
00017 #include <vnl/algo/vnl_svd.h>
00018 #include <vnl/vnl_math.h>
00019 
00020 //=======================================================================
00021 
00022 vcl_string clsfy_binary_hyperplane_ls_builder::is_a() const
00023 {
00024   return vcl_string("clsfy_binary_hyperplane_ls_builder");
00025 }
00026 
00027 //=======================================================================
00028 
00029 bool clsfy_binary_hyperplane_ls_builder::is_class(vcl_string const& s) const
00030 {
00031   return s == clsfy_binary_hyperplane_ls_builder::is_a() || clsfy_builder_base::is_class(s);
00032 }
00033 
00034 //=======================================================================
00035 
00036 void clsfy_binary_hyperplane_ls_builder::print_summary(vcl_ostream& os) const
00037 {
00038   os << is_a();
00039 }
00040 
00041 //=======================================================================
00042 
00043 //: Build a multi layer perceptron classifier, with the given data.
00044 double clsfy_binary_hyperplane_ls_builder::build(
00045   clsfy_classifier_base &classifier, mbl_data_wrapper<vnl_vector<double> > &inputs,
00046   const vcl_vector<unsigned> &outputs) const
00047 {
00048   assert(outputs.size() == inputs.size());
00049   assert(* vcl_max_element(outputs.begin(), outputs.end()) <= 1);
00050   assert(classifier.is_class("clsfy_binary_hyperplane"));
00051 
00052   clsfy_binary_hyperplane &hyperplane = (clsfy_binary_hyperplane &) classifier;
00053 
00054   inputs.reset();
00055   const unsigned k = inputs.current().size();
00056   vnl_matrix<double> XtX(k+1, k+1, 0.0);
00057   vnl_vector<double> XtY(k+1, 0.0);
00058 
00059 #if 0 // The calculation is as follows
00060   do
00061   {
00062     // XtX += [x, -1]' * [x, -1]
00063     const vnl_vector<double> &x=inputs.current();
00064     double y = outputs[inputs.index()] ? 1.0 : -1.0;
00065     vnl_vector<double> xp(k+1);
00066     xp.update(x, 0);
00067     xp(k) = -1.0;
00068     XtX += outer_product(xp, xp);
00069     double y = outputs[inputs.index()] ? 1.0 : -1.0;
00070     XtY += y * xp;
00071   } while (inputs.next());
00072 #else// However the following version is faster
00073   do
00074   {
00075     // XtX += [x, -1]' * [x, -1]
00076     const vnl_vector<double> &x=inputs.current();
00077     double y = outputs[inputs.index()] ? 1.0 : -1.0;
00078     for (unsigned i=0; i<k; ++i)
00079     {
00080       for (unsigned j=0; j<i; ++j)
00081         XtX(i,j) += x(i) * x(j);
00082       XtX(i,i) += vnl_math_sqr(x(i));
00083       XtX(i,k) -= x(i);
00084       XtY(i) += y * x(i);
00085     }
00086     XtY(k) += y * -1.0;
00087 
00088   } while (inputs.next());
00089   for (unsigned i=0; i<k; ++i)
00090   {
00091     for (unsigned j=0; j<i; ++j)
00092       XtX(j,i) += XtX(i,j);
00093     XtX(k,i) = XtX(i,k);
00094   }
00095   XtX(k, k) = (double) inputs.size();
00096 #endif
00097 
00098 
00099   // Find the solution to X w = Y;
00100   // However it is easier to find X' X w = X' Y;
00101   // because X is n_train x n_elems whereas X'X is n_elems x n_elems
00102 
00103   vnl_svd<double> svd(XtX, 1.0e-12); // 1e-12 = zero-tolerance for singular values
00104   vnl_vector<double> w = svd.solve(XtY);
00105 #if 0
00106   vcl_cerr << "XtX: " << XtX << vcl_endl
00107            << "XtY: " << XtY << vcl_endl
00108            << "w: "   << w   << vcl_endl;
00109 #endif
00110   vnl_vector<double> weights(&w(0), k);
00111   hyperplane.set(weights, w(k));
00112 
00113   return clsfy_test_error(classifier, inputs, outputs);
00114 }
00115 
00116 
00117 //=======================================================================
00118 
00119 
00120 //: Build a linear classifier, with the given data.
00121 // Return the mean error over the training set.
00122 // n_classes must be 1.
00123 double clsfy_binary_hyperplane_ls_builder::build(
00124   clsfy_classifier_base &classifier, mbl_data_wrapper<vnl_vector<double> > &inputs,
00125   unsigned n_classes, const vcl_vector<unsigned> &outputs) const
00126 {
00127   assert (n_classes == 1);
00128   return build(classifier, inputs, outputs);
00129 }
00130 
00131 //=======================================================================
00132 
00133 void clsfy_binary_hyperplane_ls_builder::b_write(vsl_b_ostream &bfs) const
00134 {
00135   const short version_no=1;
00136   vsl_b_write(bfs, version_no);
00137 }
00138 
00139 //=======================================================================
00140 
00141 void clsfy_binary_hyperplane_ls_builder::b_read(vsl_b_istream &bfs)
00142 {
00143   if (!bfs) return;
00144 
00145   short version;
00146   vsl_b_read(bfs,version);
00147   switch (version)
00148   {
00149     case (1):
00150       break;
00151     default:
00152       vcl_cerr << "I/O ERROR: clsfy_binary_hyperplane_ls_builder::b_read(vsl_b_istream&)\n"
00153                << "           Unknown version number "<< version << '\n';
00154       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00155   }
00156 }