Go to the documentation of this file.00001
00002 #include "clsfy_smo_1.h"
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #include <vcl_cmath.h>
00018 #include <vnl/vnl_math.h>
00019 #include <vcl_iostream.h>
00020 #include <vcl_cassert.h>
00021
00022
00023
00024
00025
00026 double clsfy_smo_1_lin::kernel(int i1, int i2)
00027 {
00028 if (i1==i2)
00029 return precomputed_self_dot_product_[i1];
00030 else
00031 return dot_product(data_point(i1),data_point(i2));
00032 }
00033
00034
00035
00036 double clsfy_smo_1_rbf::kernel(int i1, int i2)
00037 {
00038 if (i1==i2) return 1.0;
00039 double s = dot_product(data_point(i1),data_point(i2));
00040 s *= -2.0f;
00041 s += precomputed_self_dot_product_[i1] + precomputed_self_dot_product_[i2];
00042 return vcl_exp(gamma_ * s);
00043 }
00044
00045
00046
00047
00048
00049 void clsfy_smo_1_lin::set_data(const mbl_data_wrapper<vnl_vector<double> >& data, const vcl_vector<int>& targets)
00050 {
00051 const unsigned N = data.size();
00052 data_ = data.clone();
00053
00054 assert(targets.size() == N);
00055 target_ = targets;
00056
00057 precomputed_self_dot_product_.resize(N);
00058 for (unsigned int i=0; i<N; i++)
00059 precomputed_self_dot_product_[i] = dot_product(data_point(i),data_point(i));
00060 }
00061
00062
00063
00064 double clsfy_smo_1_lin::C() const
00065 {
00066 return C_;
00067 }
00068
00069
00070
00071 void clsfy_smo_1_lin::set_C(double C)
00072 {
00073 if (C <= 0) C_ = vnl_huge_val(double());
00074 else C_ = C;
00075 }
00076
00077
00078
00079
00080 double clsfy_smo_1_rbf::gamma() const
00081 {
00082 return -gamma_;
00083 }
00084
00085
00086
00087
00088
00089 void clsfy_smo_1_rbf::set_gamma(double gamma)
00090 {
00091 gamma_ = -gamma;
00092 }
00093
00094
00095
00096 clsfy_smo_1_rbf::clsfy_smo_1_rbf():
00097 gamma_((double)-0.5)
00098 {
00099 }
00100
00101
00102
00103 clsfy_smo_1_lin::clsfy_smo_1_lin():
00104 C_(vnl_huge_val(double()))
00105 {
00106 }
00107
00108
00109
00110 int clsfy_smo_1_lin::take_step(int i1, int i2, double E1)
00111 {
00112 int s;
00113 double a1, a2;
00114 double E2, L, H, k11, k22, k12, eta, Lobj, Hobj;
00115
00116 if (i1 == i2) return 0;
00117
00118
00119 const double alph1 = alph_[i1];
00120 const int y1 = target_[i1];
00121
00122 const double alph2 = alph_[i2];
00123 const int y2 = target_[i2];
00124 if (alph2 > 0 && alph2 < C_)
00125 E2 = error_cache_[i2];
00126 else
00127 E2 = learned_func(i2) - y2;
00128
00129 s = y1 * y2;
00130
00131 if (y1 == y2) {
00132 const double g = alph1 + alph2;
00133 if (g > C_) {
00134 L = g-C_;
00135 H = C_;
00136 }
00137 else {
00138 L = 0;
00139 H = g;
00140 }
00141 }
00142 else
00143 {
00144 const double g = alph1 - alph2;
00145 if (g > 0) {
00146 L = 0;
00147 H = C_ - g;
00148 }
00149 else {
00150 L = -g;
00151 H = C_;
00152 }
00153 }
00154
00155 if (L == H)
00156 return 0;
00157
00158 k11 = kernel(i1, i1);
00159 k12 = kernel(i1, i2);
00160 k22 = kernel(i2, i2);
00161 eta = 2.0 * k12 - k11 - k22;
00162
00163
00164 if (eta < 0) {
00165 a2 = alph2 + y2 * (E2 - E1) / eta;
00166 if (a2 < L)
00167 a2 = L;
00168 else if (a2 > H)
00169 a2 = H;
00170 }
00171 else {
00172 {
00173 double c1 = eta/2;
00174 double c2 = y2 * (E1-E2)- eta * alph2;
00175 Lobj = c1 * L * L + c2 * L;
00176 Hobj = c1 * H * H + c2 * H;
00177 }
00178
00179 if (Lobj > Hobj+eps_)
00180 a2 = L;
00181 else if (Lobj < Hobj-eps_)
00182 a2 = H;
00183 else
00184 a2 = alph2;
00185 }
00186
00187 if (vnl_math_abs(a2-alph2) < eps_*(a2+alph2+eps_) )
00188 return 0;
00189
00190 a1 = alph1 - s * (a2 - alph2);
00191 if (a1 < 0.0) {
00192 a2 += s * a1;
00193 a1 = 0;
00194 }
00195 else if (a1 > C_) {
00196 double t = a1-C_;
00197 a2 += s * t;
00198 a1 = C_;
00199 }
00200
00201
00202 double delta_b;
00203 {
00204 double b1, b2, bnew;
00205
00206 const double eps_2 = eps_*eps_;
00207
00208 if (a1 > eps_2 && a1 < (C_*(1-eps_2)))
00209 bnew = b_ + E1 + y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12;
00210 else {
00211 if (a2 > eps_2 && a2 < (C_*(1-eps_2)))
00212 bnew = b_ + E2 + y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22;
00213 else {
00214 b1 = b_ + E1 + y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12;
00215 b2 = b_ + E2 + y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22;
00216 bnew = (b1 + b2) / 2;
00217 }
00218 }
00219 delta_b = bnew - b_;
00220 b_ = bnew;
00221 }
00222
00223 {
00224 const double t1 = y1 * (a1-alph1);
00225 const double t2 = y2 * (a2-alph2);
00226
00227 for (unsigned int i=0; i<data_->size(); i++)
00228 if (0 < alph_[i] && alph_[i] < C_)
00229 error_cache_[i] += t1 * kernel(i1,i) + t2 * kernel(i2,i)
00230 - delta_b;
00231 error_cache_[i1] = 0.0;
00232 error_cache_[i2] = 0.0;
00233 }
00234
00235 alph_[i1] = a1;
00236 alph_[i2] = a2;
00237
00238 return 1;
00239 }
00240
00241
00242
00243 int clsfy_smo_1_lin::examine_example(int i1)
00244 {
00245 double E1, r1;
00246 const unsigned long N = data_->size();
00247
00248 const double y1 = target_[i1];
00249 const double alph1 = alph_(i1);
00250
00251 if (alph1 > 0 && alph1 < C_)
00252 E1 = error_cache_[i1];
00253 else
00254 E1 = learned_func(i1) - y1;
00255
00256 r1 = y1 * E1;
00257 if ((r1 < -tolerance_ && alph1 < C_) ||
00258 (r1 > tolerance_ && alph1 > 0))
00259 {
00260
00261 {
00262 unsigned int k;
00263 int i2;
00264 double tmax;
00265
00266
00267
00268
00269 for (i2 = (-1), tmax = 0, k = 0; k < N; ++k)
00270 if (alph_(k) > 0 && alph_(k) < C_)
00271 {
00272 double E2, temp;
00273
00274 E2 = error_cache_[k];
00275 temp = vnl_math_abs(E1 - E2);
00276 if (temp > tmax)
00277 {
00278 tmax = temp;
00279 i2 = k;
00280 }
00281 }
00282
00283 if (i2 >= 0) {
00284 if (take_step (i1, i2, E1))
00285 return 1;
00286 }
00287 }
00288
00289
00290
00291
00292 for (unsigned long k0 = rng_.lrand32(N-1), k = k0; k < N + k0; ++k)
00293 {
00294 unsigned long i2 = k % N;
00295 if (alph_(i2) > 0 && alph_(i2) < C_)
00296 {
00297 if (take_step(i1, i2, E1))
00298 return 1;
00299 }
00300 }
00301
00302
00303
00304
00305 for (unsigned long k0 = rng_.lrand32(N-1), k = k0; k < N + k0; ++k)
00306 {
00307 unsigned long i2 = k % N;
00308 if (alph_(i2) == 0 || alph_(i2) == C_)
00309 {
00310 if (take_step(i1, i2, E1))
00311 return 1;
00312 }
00313 }
00314 }
00315 return 0;
00316 }
00317
00318
00319
00320 int clsfy_smo_1_rbf::calc()
00321 {
00322 assert(gamma_!=0.0);
00323 return clsfy_smo_1_lin::calc();
00324 }
00325
00326
00327
00328 int clsfy_smo_1_lin::calc()
00329 {
00330
00331
00332 assert (data_ != 0);
00333
00334 const unsigned long N = data_->size();
00335 assert(N != 0);
00336
00337 if (alph_.empty())
00338 {
00339 alph_.set_size(N);
00340 alph_.fill(0.0);
00341 }
00342
00343
00344
00345 error_cache_.resize(N);
00346
00347 unsigned long numChanged = 0;
00348 bool examineAll = true;
00349 while (numChanged > 0 || examineAll)
00350 {
00351 numChanged = 0;
00352 if (examineAll)
00353 for (unsigned int k = 0; k < N; k++)
00354 numChanged += examine_example (k);
00355 else
00356 for (unsigned int k = 0; k < N; k++)
00357 if (alph_[k] != 0 && alph_[k] != C_)
00358 numChanged += examine_example (k);
00359 if (examineAll)
00360 examineAll = false;
00361 else if (numChanged == 0)
00362 examineAll = true;
00363
00364 #if !defined NDEBUG && CLSFY_SMO_BASE_PRINT_PROGRESS >1
00365 {
00366 double s = 0.;
00367 for (int i=0; i<N; i++)
00368 s += alph_[i];
00369 double t = 0.;
00370 for (int i=0; i<N; i++)
00371 {
00372 if (alph_(i) != 0.0)
00373 {
00374 for (int j=0; j<N; j++)
00375 if (alph_[j] != 0.0)
00376 t += alph_[i]*alph_[j]*target_[i]*target_[j]*kernel(i,j);
00377 }
00378 }
00379 vcl_cerr << "Objective function=" << (s - t/2.) << '\t';
00380 for (int i=0; i<N; i++)
00381 if (alph_[i] < 0)
00382 vcl_cerr << "alph_[" << i << "]=" << alph_[i] << " < 0\n";
00383 s = 0.;
00384 for (int i=0; i<N; i++)
00385 s += alph_[i] * target_[i];
00386 vcl_cerr << "s=" << s << "\terror_rate=" << error_rate() << '\t';
00387 }
00388 #endif
00389
00390 #if !defined NDEBUG && CLSFY_SMO_BASE_PRINT_PROGRESS
00391 {
00392 int non_bound_support =0;
00393 int bound_support =0;
00394 for (int i=0; i<N; i++)
00395 if (alph_[i] > 0)
00396 {
00397 if (alph_[i] < C_)
00398 non_bound_support++;
00399 else
00400 bound_support++;
00401 }
00402 vcl_cerr << "non_bound=" << non_bound_support << '\t'
00403 << "bound_support=" << bound_support << vcl_endl;
00404 }
00405 #endif
00406 }
00407
00408 error_ = error_rate();
00409
00410 #if !defined NDEBUG && CLSFY_SMO_BASE_PRINT_PROGRESS
00411 vcl_cerr << "Threshold=" << b_ << vcl_endl;
00412 vcl_cout << "Error rate=" << error_ << vcl_endl;
00413 #endif
00414
00415 return 0;
00416 }