Go to the documentation of this file.00001
00002
00003 #include "clsfy_rbf_svm.h"
00004
00005
00006
00007
00008
00009
00010 #include <vcl_string.h>
00011 #include <vcl_cassert.h>
00012 #include <vcl_cstdlib.h>
00013 #include <vsl/vsl_indent.h>
00014 #include <vsl/vsl_vector_io.h>
00015 #include <vnl/io/vnl_io_vector.h>
00016 #include <vnl/vnl_math.h>
00017
00018
00019
00020 clsfy_rbf_svm::clsfy_rbf_svm()
00021 {
00022 }
00023
00024
00025
00026 clsfy_rbf_svm::~clsfy_rbf_svm()
00027 {
00028 }
00029
00030
00031
00032
00033 double clsfy_rbf_svm::kernel(const vnl_vector<double> &v1,
00034 const vnl_vector<double> &v2) const
00035 {
00036 return vcl_exp(gamma_*vnl_vector_ssd(v1, v2));
00037 }
00038
00039
00040
00041
00042
00043
00044 unsigned clsfy_rbf_svm::classify(const vnl_vector<double> &input) const
00045 {
00046 int n = supports_.size();
00047 double sum =- bias_;
00048 double upper_target = upper_target_;
00049 double lower_target = lower_target_;
00050 int i;
00051 for (i =0; i<n; i++)
00052 {
00053 const double l = lagrangians_[i];
00054 if (l <0) upper_target += l;
00055 else lower_target += l;
00056 sum += l * vcl_exp(gamma_*localEuclideanDistanceSq(input, supports_[i]));
00057 if (sum > upper_target) return 1u;
00058 else if (sum < lower_target) return 0u;
00059 }
00060 vcl_cerr << "ERROR: clsfy_rbf_svm::classify"
00061 << " Should not have reached here\n";
00062 vcl_abort();
00063 return 0u;
00064 }
00065
00066
00067
00068
00069
00070
00071
00072 double clsfy_rbf_svm::log_l(const vnl_vector<double> &input) const
00073 {
00074 int n = supports_.size();
00075 double sum =0.0;
00076 for (int i =0; i<n; i++)
00077 sum += lagrangians_[i] * vcl_exp(gamma_*vnl_vector_ssd(input, supports_[i]));
00078 return sum - bias_;
00079 }
00080
00081
00082
00083
00084
00085
00086
00087
00088 void clsfy_rbf_svm::class_probabilities(vcl_vector<double> &outputs,
00089 const vnl_vector<double> &input) const
00090 {
00091 outputs.resize(1);
00092 double Likely = vcl_exp(log_l(input));
00093 if (Likely == vnl_huge_val(double()))
00094 outputs[0] = 1;
00095 else
00096 outputs[0] = Likely / (1+Likely);
00097 return;
00098 }
00099
00100
00101
00102
00103 void clsfy_rbf_svm::calculate_targets()
00104 {
00105 upper_target_ = lower_target_=0;
00106
00107 const unsigned n = supports_.size();
00108 for (unsigned i =0;i<n;i++)
00109 {
00110 const double l = lagrangians()[i];
00111 if (l < 0) upper_target_ -= l;
00112 else lower_target_ -= l;
00113 }
00114 }
00115
00116
00117
00118
00119
00120 void clsfy_rbf_svm::set( const vcl_vector<vnl_vector<double> > &supportVectors,
00121 const vcl_vector<double> &lagrangianAlphas,
00122 const vcl_vector<unsigned> &labels,
00123 double RBFWidth, double bias)
00124 {
00125 unsigned int n = supportVectors.size();
00126 assert(n == lagrangianAlphas.size());
00127 assert(n == labels.size());
00128 supports_ = supportVectors;
00129
00130
00131 lagrangians_ = lagrangianAlphas;
00132 for (unsigned int i=0; i<n; i++)
00133 lagrangians_[i] *= (labels[i]?1:-1);
00134
00135 gamma_ = -0.5/(RBFWidth*RBFWidth);
00136 bias_ = bias;
00137 calculate_targets();
00138 }
00139
00140
00141
00142
00143 vcl_string clsfy_rbf_svm::is_a() const
00144 {
00145 return vcl_string("clsfy_rbf_svm");
00146 }
00147
00148
00149
00150 bool clsfy_rbf_svm::is_class(vcl_string const& s) const
00151 {
00152 return s == clsfy_rbf_svm::is_a() || clsfy_classifier_base::is_class(s);
00153 }
00154
00155
00156
00157 short clsfy_rbf_svm::version_no() const
00158 {
00159 return 2;
00160 }
00161
00162
00163
00164 clsfy_classifier_base* clsfy_rbf_svm::clone() const
00165 {
00166 return new clsfy_rbf_svm(*this);
00167 }
00168
00169
00170
00171 void clsfy_rbf_svm::print_summary(vcl_ostream& os) const
00172 {
00173 os << vsl_indent() << "bias=" << bias_ << " sigma=" << rbf_width()
00174 << " nSupportVectors=" << n_support_vectors() << '\n'
00175 << vsl_indent() <<" Starting targets are " << upper_target_
00176 << ", " << lower_target_ << vcl_endl;
00177 }
00178
00179
00180
00181 void clsfy_rbf_svm::b_write(vsl_b_ostream& bfs) const
00182 {
00183 vsl_b_write(bfs,version_no());
00184 vsl_b_write(bfs,bias_);
00185 vsl_b_write(bfs,gamma_);
00186 vsl_b_write(bfs,lagrangians_);
00187 vsl_b_write(bfs,supports_);
00188 }
00189
00190
00191
00192 void clsfy_rbf_svm::b_read(vsl_b_istream& bfs)
00193 {
00194 if (!bfs) return;
00195
00196 double dummy;
00197 short version;
00198 vsl_b_read(bfs,version);
00199 switch (version)
00200 {
00201 case 1:
00202 vcl_cerr << "WARNING: clsfy_rbf_svm::b_read().\n"
00203 << "Version 1 shouldn't really be loaded into this class.\n";
00204
00205 vsl_b_read(bfs,bias_);
00206 vsl_b_read(bfs,gamma_);
00207 vsl_b_read(bfs,lagrangians_);
00208 vsl_b_read(bfs,supports_);
00209 vsl_b_read(bfs,dummy);
00210 vsl_b_read(bfs,dummy);
00211 calculate_targets();
00212 break;
00213 case 2:
00214 vsl_b_read(bfs,bias_);
00215 vsl_b_read(bfs,gamma_);
00216 vsl_b_read(bfs,lagrangians_);
00217 vsl_b_read(bfs,supports_);
00218 calculate_targets();
00219 break;
00220 default:
00221 vcl_cerr << "I/O ERROR: clsfy_rbf_svm::b_read(vsl_b_istream&)\n"
00222 << " Unknown version number "<< version << '\n';
00223 bfs.is().clear(vcl_ios::badbit);
00224 return;
00225 }
00226 }