contrib/mul/clsfy/clsfy_rbf_svm_smo_1_builder.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_rbf_svm_smo_1_builder.cxx
00002 // Copyright: (C) 2001 British Telecommunications plc.
00003 #include "clsfy_rbf_svm_smo_1_builder.h"
00004 //:
00005 // \file
00006 // \brief Implement an interface to SMO algorithm SVM builder and additional logic
00007 // \author Ian Scott
00008 // \date Dec 2001
00009 
00010 //=======================================================================
00011 
#include <vcl_algorithm.h>
#include <vcl_cassert.h>
#include <vcl_iostream.h>
#include <vcl_limits.h>
#include <vcl_sstream.h>
#include <vcl_string.h>
#include <vcl_vector.h>
#include <vul/vul_string.h>
00018 
00019 #include <mbl/mbl_data_wrapper.h>
00020 #include <mbl/mbl_parse_block.h>
00021 #include <mbl/mbl_read_props.h>
00022 
00023 #include <clsfy/clsfy_smo_1.h>
00024 
00025 //=======================================================================
00026 
//: Map a 0/1 class label onto the +1/-1 target values handed to the SMO solver.
// Label 1 maps to +1; every other label maps to -1.
// Declared static so this file-local helper has internal linkage and cannot
// collide (ODR) with a same-named inline helper in another translation unit.
static inline int class_to_svm_target(unsigned v) { return v == 1 ? 1 : -1; }
00028 
00029 //=======================================================================
00030 //: Build classifier from data
00031 // returns the training error, or +INF if there is an error.
00032 double clsfy_rbf_svm_smo_1_builder::build(clsfy_classifier_base& classifier,
00033                                           mbl_data_wrapper<vnl_vector<double> >& inputs,
00034                                           const vcl_vector<unsigned> &outputs) const
00035 {
00036   inputs.reset();
00037 //const unsigned int nDims = inputs.current().size(); // unused variable
00038   const unsigned int nSamples = inputs.size();
00039   assert(outputs.size() == nSamples);
00040   assert(*vcl_max_element(outputs.begin(), outputs.end()) <= 1);
00041 
00042   assert(classifier.is_class("clsfy_rbf_svm"));
00043   clsfy_rbf_svm &svm = static_cast<clsfy_rbf_svm &>(classifier);
00044 
00045   clsfy_smo_1_rbf svAPI;
00046   vcl_vector<int> targets(nSamples);
00047   vcl_transform(outputs.begin(), outputs.end(),
00048                 targets.begin(), class_to_svm_target);
00049 
00050   svAPI.set_data(inputs, targets);
00051 
00052 
00053   // Set the SVM solver parameters
00054   svAPI.set_C(boundC_);
00055   svAPI.set_gamma(1.0/(2.0*rbf_width_*rbf_width_));
00056   // Solve the SVM
00057   svAPI.calc();
00058 
00059 
00060   // Get the SVM description, and build an SVM machine
00061   {
00062     vcl_vector<vnl_vector<double> > supportVectors;
00063     const vnl_vector<double> &allAlphas = svAPI.lagrange_mults();
00064     vcl_vector<double> alphas;
00065     vcl_vector<unsigned> labels;
00066     for (unsigned i=0; i<nSamples; ++i)
00067       if (allAlphas[i]!=0.0)
00068       {
00069         alphas.push_back(allAlphas[i]);
00070         labels.push_back(outputs[i]);
00071         inputs.set_index(i);
00072         supportVectors.push_back(inputs.current());
00073       }
00074     svm.set(supportVectors, alphas, labels, rbf_width_, svAPI.bias());
00075   }
00076 
00077   return svAPI.error_rate();
00078 }
00079 
00080 //=======================================================================
00081 //: Build classifier from data.
00082 // returns the training error, or +INF if there is an error.
00083 // nClasses must be 1.
00084 double clsfy_rbf_svm_smo_1_builder::build(clsfy_classifier_base& classifier,
00085                                           mbl_data_wrapper<vnl_vector<double> >& inputs,
00086                                           unsigned nClasses,
00087                                           const vcl_vector<unsigned> &outputs) const
00088 {
00089   assert(nClasses == 1);
00090   return build(classifier, inputs, outputs);
00091 }
00092 
00093 //=======================================================================
00094 
00095 double clsfy_rbf_svm_smo_1_builder::rbf_width() const
00096 {
00097   return rbf_width_;
00098 }
00099 
00100 //=======================================================================
00101 
00102 void clsfy_rbf_svm_smo_1_builder::set_rbf_width(double rbf_width)
00103 {
00104   rbf_width_ = rbf_width;
00105 }
00106 //=======================================================================
00107 
00108 vcl_string clsfy_rbf_svm_smo_1_builder::is_a() const
00109 {
00110   return vcl_string("clsfy_rbf_svm_smo_1_builder");
00111 }
00112 
00113 //=======================================================================
00114 
00115 bool clsfy_rbf_svm_smo_1_builder::is_class(vcl_string const& s) const
00116 {
00117   return s == clsfy_rbf_svm_smo_1_builder::is_a() || clsfy_builder_base::is_class(s);
00118 }
00119 
00120 //=======================================================================
00121 
00122 short clsfy_rbf_svm_smo_1_builder::version_no() const
00123 {
00124   return 1;
00125 }
00126 
00127 //=======================================================================
00128 
00129 clsfy_builder_base* clsfy_rbf_svm_smo_1_builder::clone() const
00130 {
00131   return new clsfy_rbf_svm_smo_1_builder(*this);
00132 }
00133 
00134 //=======================================================================
00135 
00136 void clsfy_rbf_svm_smo_1_builder::print_summary(vcl_ostream& os) const
00137 {
00138   // os << data_; // example of data output
00139   os << "RBF width = " << rbf_width_ << ", bounds = " << boundC_;
00140 }
00141 
00142 //=======================================================================
00143 
00144 void clsfy_rbf_svm_smo_1_builder::b_write(vsl_b_ostream& bfs) const
00145 {
00146   vsl_b_write(bfs,version_no());
00147   vsl_b_write(bfs,boundC_);
00148   vsl_b_write(bfs,rbf_width_);
00149 }
00150 
00151 //=======================================================================
00152 
00153 void clsfy_rbf_svm_smo_1_builder::b_read(vsl_b_istream& bfs)
00154 {
00155   if (!bfs) return;
00156 
00157   short version;
00158   vsl_b_read(bfs,version);
00159   switch (version)
00160   {
00161   case (1):
00162     vsl_b_read(bfs,boundC_);
00163     vsl_b_read(bfs,rbf_width_);
00164     break;
00165   default:
00166     vcl_cerr << "I/O ERROR: clsfy_rbf_svm_smo_1_builder::b_read(vsl_b_istream&)\n"
00167              << "           Unknown version number "<< version << '\n';
00168     bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00169     return;
00170   }
00171 }
00172 
00173 
00174 //=======================================================================
00175 //: Initialise the parameters from a text stream.
00176 // The next non-ws character in the stream should be a '{'
00177 // \verbatim
00178 // {
00179 //   boundC: 3  (default 0 meaning no bound) Upper bound on the Lagrange multiplies.
00180 //              Smaller non-zero values result in a softening of the boundary.
00181 //
00182 //   rbf_width: 3.0  (required) - A good guess is the mean euclidean distance
00183 //                    to every examples nearest neighbour.
00184 // }
00185 // \endverbatim
00186 // \throw mbl_exception_parse_error if the parse fails.
00187 void clsfy_rbf_svm_smo_1_builder::config(vcl_istream &as)
00188 {
00189  vcl_string s = mbl_parse_block(as);
00190 
00191   vcl_istringstream ss(s);
00192   mbl_read_props_type props = mbl_read_props_ws(ss);
00193 
00194   {
00195     boundC_= vul_string_atof(props.get_optional_property("boundC", "0.0"));
00196     rbf_width_= vul_string_atof(props.get_optional_property("rbf_width", "0.0"));
00197   }
00198 
00199   // Check for unused props
00200   mbl_read_props_look_for_unused_props(
00201     "clsfy_rbf_svm_smo_1_builder::config", props, mbl_read_props_type());
00202 }