contrib/mul/clsfy/clsfy_rbf_svm.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_rbf_svm.cxx
00002 // Copyright: (C) 2001 British Telecommunications plc
00003 #include "clsfy_rbf_svm.h"
00004 //:
00005 // \file
00006 // \brief Implement a RBF Support Vector Machine
00007 // \author Ian Scott
00008 // \date Jan 2001
00009 
00010 #include <vcl_string.h>
00011 #include <vcl_cassert.h>
00012 #include <vcl_cstdlib.h> // for vcl_abort()
00013 #include <vsl/vsl_indent.h>
00014 #include <vsl/vsl_vector_io.h>
00015 #include <vnl/io/vnl_io_vector.h>
00016 #include <vnl/vnl_math.h>
00017 
00018 //=======================================================================
00019 
00020 clsfy_rbf_svm::clsfy_rbf_svm()
00021 {
00022 }
00023 
00024 //=======================================================================
00025 
00026 clsfy_rbf_svm::~clsfy_rbf_svm()
00027 {
00028 }
00029 
00030 
00031 //==================================================================
00032 
00033 double clsfy_rbf_svm::kernel(const vnl_vector<double> &v1,
00034                              const vnl_vector<double> &v2) const
00035 {
00036   return vcl_exp(gamma_*vnl_vector_ssd(v1, v2));
00037 }
00038 
00039 //=======================================================================
00040 
00041 //: Classify the input vector.
00042 // Returns 0 to indicate out of (or negative) class and one to
00043 // indicate in class (or positive.)
00044 unsigned clsfy_rbf_svm::classify(const vnl_vector<double> &input) const
00045 {
00046   int n = supports_.size();
00047   double sum =- bias_;
00048   double upper_target = upper_target_;
00049   double lower_target = lower_target_;
00050   int i;
00051   for (i =0; i<n; i++)
00052   {
00053     const double l = lagrangians_[i];
00054     if (l <0) upper_target += l;
00055     else lower_target += l;
00056     sum += l * vcl_exp(gamma_*localEuclideanDistanceSq(input, supports_[i]));
00057     if (sum > upper_target) return 1u;
00058     else if (sum < lower_target) return 0u;
00059   }
00060   vcl_cerr << "ERROR: clsfy_rbf_svm::classify"
00061            << " Should not have reached here\n";
00062   vcl_abort();
00063   return 0u;
00064 }
00065 
00066 //=======================================================================
00067 
00068 //: Log likelihood of being in class (binary classifiers only).
00069 // class probability = vcl_exp(logL) / (1+vcl_exp(logL))
00070 // This is not a strict log likelihood value, since SVMs do not give Bayesian
00071 // outputs. However its properties fit the requirements of a log likelihood value.
00072 double clsfy_rbf_svm::log_l(const vnl_vector<double> &input) const
00073 {
00074   int n = supports_.size();
00075   double sum =0.0;
00076   for (int i =0; i<n; i++)
00077     sum += lagrangians_[i] * vcl_exp(gamma_*vnl_vector_ssd(input, supports_[i]));
00078   return sum - bias_;
00079 }
00080 
00081 //=======================================================================
00082 
00083 //: Return the probability the input being in each class.
00084 // output(i) i<<nClasses, contains the probability that the input
00085 // is in class i;
00086 // This are not strict probability values, since SVMs do not give Bayesian outputs. However
00087 // their properties fit the requirements of a probability.
00088 void clsfy_rbf_svm::class_probabilities(vcl_vector<double> &outputs,
00089                                         const vnl_vector<double> &input) const
00090 {
00091   outputs.resize(1);
00092   double Likely = vcl_exp(log_l(input));
00093   if (Likely == vnl_huge_val(double()))
00094     outputs[0] = 1;
00095   else
00096     outputs[0] = Likely / (1+Likely);
00097   return;
00098 }
00099 
00100 //=======================================================================
00101 
00102 //: Set the private target member values to the correct value.
00103 void clsfy_rbf_svm::calculate_targets()
00104 {
00105   upper_target_ = lower_target_=0;
00106 
00107   const unsigned n = supports_.size();
00108   for (unsigned i =0;i<n;i++)
00109   {
00110     const double l = lagrangians()[i];
00111     if (l < 0)  upper_target_ -= l;
00112     else        lower_target_ -= l;
00113   }
00114 }
00115 
00116 
00117 //=======================================================================
00118 
00119 //: Set the internal values defining the classifier.
00120 void clsfy_rbf_svm::set( const vcl_vector<vnl_vector<double> > &supportVectors,
00121                          const vcl_vector<double> &lagrangianAlphas,
00122                          const vcl_vector<unsigned> &labels,
00123                          double RBFWidth, double bias)
00124 {
00125   unsigned int n = supportVectors.size();
00126   assert(n == lagrangianAlphas.size());
00127   assert(n == labels.size());
00128   supports_ = supportVectors;
00129 
00130   // pre-multiply Lagrangians with output labels.
00131   lagrangians_ = lagrangianAlphas;
00132   for (unsigned int i=0; i<n; i++)
00133     lagrangians_[i] *= (labels[i]?1:-1);
00134 
00135   gamma_ = -0.5/(RBFWidth*RBFWidth);
00136   bias_ = bias;
00137   calculate_targets();
00138 }
00139 
00140 
00141 //=======================================================================
00142 
00143 vcl_string clsfy_rbf_svm::is_a() const
00144 {
00145   return vcl_string("clsfy_rbf_svm");
00146 }
00147 
00148 //=======================================================================
00149 
00150 bool clsfy_rbf_svm::is_class(vcl_string const& s) const
00151 {
00152   return s == clsfy_rbf_svm::is_a() || clsfy_classifier_base::is_class(s);
00153 }
00154 
00155 //=======================================================================
00156 
00157 short clsfy_rbf_svm::version_no() const
00158 {
00159   return 2;
00160 }
00161 
00162 //=======================================================================
00163 
00164 clsfy_classifier_base* clsfy_rbf_svm::clone() const
00165 {
00166   return new clsfy_rbf_svm(*this);
00167 }
00168 
00169 //=======================================================================
00170 
00171 void clsfy_rbf_svm::print_summary(vcl_ostream& os) const
00172 {
00173   os << vsl_indent() << "bias=" << bias_ << "  sigma=" << rbf_width()
00174      << "  nSupportVectors=" << n_support_vectors() << '\n'
00175      << vsl_indent() <<"  Starting targets are  " << upper_target_
00176      << ", " << lower_target_ << vcl_endl;
00177 }
00178 
00179 //=======================================================================
00180 
00181 void clsfy_rbf_svm::b_write(vsl_b_ostream& bfs) const
00182 {
00183   vsl_b_write(bfs,version_no());
00184   vsl_b_write(bfs,bias_);
00185   vsl_b_write(bfs,gamma_);
00186   vsl_b_write(bfs,lagrangians_);
00187   vsl_b_write(bfs,supports_);
00188 }
00189 
00190 //=======================================================================
00191 
00192 void clsfy_rbf_svm::b_read(vsl_b_istream& bfs)
00193 {
00194   if (!bfs) return;
00195 
00196   double dummy;
00197   short version;
00198   vsl_b_read(bfs,version);
00199   switch (version)
00200   {
00201    case 1:
00202     vcl_cerr << "WARNING: clsfy_rbf_svm::b_read().\n"
00203              << "Version 1 shouldn't really be loaded into this class.\n";
00204 
00205     vsl_b_read(bfs,bias_);
00206     vsl_b_read(bfs,gamma_);
00207     vsl_b_read(bfs,lagrangians_);
00208     vsl_b_read(bfs,supports_);
00209     vsl_b_read(bfs,dummy);
00210     vsl_b_read(bfs,dummy);
00211     calculate_targets();
00212     break;
00213    case 2:
00214     vsl_b_read(bfs,bias_);
00215     vsl_b_read(bfs,gamma_);
00216     vsl_b_read(bfs,lagrangians_);
00217     vsl_b_read(bfs,supports_);
00218     calculate_targets();
00219     break;
00220    default:
00221     vcl_cerr << "I/O ERROR: clsfy_rbf_svm::b_read(vsl_b_istream&)\n"
00222              << "           Unknown version number "<< version << '\n';
00223     bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00224     return;
00225   }
00226 }