contrib/mul/clsfy/clsfy_smo_1.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_smo_1.cxx
00002 #include "clsfy_smo_1.h"
00003 //:
00004 // \file
00005 // \author Ian Scott
00006 // \date 14-Nov-2001
00007 // \brief Sequential Minimum Optimisation algorithm
00008 // This code is based on the C++ code of
00009 // Xianping Ge, ( http://www.ics.uci.edu/~xge ) which he kindly
00010 // put in the public domain.
00011 // That code was in turn based on the algorithms of
00012 // John Platt, ( http://research.microsoft.com/~jplatt ) described in
00013 // Platt, J. C. (1998). Fast Training of Support Vector Machines Using Sequential
00014 // Minimal Optimisation. In Advances in Kernel Methods - Support Vector Learning.
00015 // B. Scholkopf, C. Burges and A. Smola, MIT Press: 185-208. and other papers.
00016 
00017 #include <vcl_cmath.h>
00018 #include <vnl/vnl_math.h>
00019 #include <vcl_iostream.h>
00020 #include <vcl_cassert.h>
00021 
00022 // Linear SMO
00023 
00024 // ----------------------------------------------------------------
00025 
00026 double clsfy_smo_1_lin::kernel(int i1, int i2)
00027 {
00028   if (i1==i2)
00029     return precomputed_self_dot_product_[i1];
00030   else
00031     return dot_product(data_point(i1),data_point(i2));
00032 }
00033 
00034 // ----------------------------------------------------------------
00035 
00036 double clsfy_smo_1_rbf::kernel(int i1, int i2)
00037 {
00038   if (i1==i2) return 1.0;
00039   double s = dot_product(data_point(i1),data_point(i2));
00040   s *= -2.0f;
00041   s += precomputed_self_dot_product_[i1] + precomputed_self_dot_product_[i2];
00042   return vcl_exp(gamma_ * s);
00043 }
00044 
00045 // ----------------------------------------------------------------
00046 
00047 //: Takes a copy of the data wrapper, but not the data.
00048 // Be careful not to destroy the underlying data while using this object.
00049 void clsfy_smo_1_lin::set_data(const mbl_data_wrapper<vnl_vector<double> >& data, const vcl_vector<int>& targets)
00050 {
00051   const unsigned N = data.size();
00052   data_ = data.clone();
00053 
00054   assert(targets.size() == N);
00055   target_ = targets;
00056 
00057   precomputed_self_dot_product_.resize(N);
00058   for (unsigned int i=0; i<N; i++)
00059       precomputed_self_dot_product_[i] = dot_product(data_point(i),data_point(i));
00060 }
00061 
00062 // ----------------------------------------------------------------
00063 
00064 double clsfy_smo_1_lin::C() const
00065 {
00066   return C_;
00067 }
00068 
00069 // ----------------------------------------------------------------
00070 
00071 void clsfy_smo_1_lin::set_C(double C)
00072 {
00073   if (C <= 0) C_ = vnl_huge_val(double());
00074   else C_ = C;
00075 }
00076 
00077 // ----------------------------------------------------------------
00078 
00079 //: 0.5 sigma^-2, where sigma is the width of the Gaussian kernel
00080 double clsfy_smo_1_rbf::gamma() const
00081 {
00082   return -gamma_;
00083 }
00084 
00085 // ----------------------------------------------------------------
00086 
00087 //: Control sigma, the width of the Gaussian kernel.
00088 // gamma = 0.5 sigma^-2
00089 void clsfy_smo_1_rbf::set_gamma(double gamma)
00090 {
00091   gamma_ = -gamma;
00092 }
00093 
00094 // ----------------------------------------------------------------
00095 
00096 clsfy_smo_1_rbf::clsfy_smo_1_rbf():
00097   gamma_((double)-0.5)
00098 {
00099 }
00100 
00101 // ----------------------------------------------------------------
00102 
00103 clsfy_smo_1_lin::clsfy_smo_1_lin():
00104   C_(vnl_huge_val(double()))
00105 {
00106 }
00107 
00108 // ----------------------------------------------------------------
00109 
//: Jointly optimise the Lagrange multipliers of examples i1 and i2.
// This follows Platt's takeStep. The pair is moved along the line on which
// y1*alph_[i1] + y2*alph_[i2] stays constant (assuming +1/-1 targets) and
// clipped to the box [0, C_]. The threshold b_ and the cached errors of the
// non-bound examples are then updated incrementally.
// \param i1  index of the first example
// \param i2  index of the second example
// \param E1  current prediction error of example i1, as supplied by
//            examine_example() (cached value or learned_func(i1) - y1)
// \return 1 if alph_[i1] and alph_[i2] were changed; 0 if no progress was
//         possible (i1 == i2, an empty feasible segment, or a negligible step)
int clsfy_smo_1_lin::take_step(int i1, int i2, double E1)
{
  int  s;              // y1*y2
  double a1, a2;       // new values of alpha_1, alpha_2
  double E2, L, H, k11, k22, k12, eta, Lobj, Hobj;

  // Optimising an example against itself cannot make progress.
  if (i1 == i2) return 0;


  const double alph1 = alph_[i1]; // old_values of alpha_1
  const int y1 = target_[i1];

  const double alph2 = alph_[i2]; // old_values of alpha_2
  const int y2 = target_[i2];
  // Error of example i2: taken from the cache when alpha_2 is non-bound,
  // otherwise recomputed from the current decision function.
  if (alph2 > 0 && alph2 < C_)
    E2 = error_cache_[i2];
  else
    E2 = learned_func(i2) - y2;

  s = y1 * y2;

  // [L, H] is the range alpha_2 may take while both multipliers stay inside
  // [0, C_] and y1*alpha_1 + y2*alpha_2 is preserved.
  if (y1 == y2) {
    const double g = alph1 + alph2;
    if (g > C_) {
        L = g-C_;
        H = C_;
    }
    else {
        L = 0;
        H = g;
    }
  }
  else
  {
    const double g = alph1 - alph2;
    if (g > 0) {
        L = 0;
        H = C_ - g;
    }
    else {
        L = -g;
        H = C_;
    }
  }

  // Feasible segment has zero length: nothing to optimise.
  if (L == H)
    return 0;

  k11 = kernel(i1, i1);
  k12 = kernel(i1, i2);
  k22 = kernel(i2, i2);
  // Second derivative of the objective along the constraint line. Note the
  // sign: this is the NEGATIVE of Platt's eta = k11 + k22 - 2*k12.
  eta = 2.0 * k12 - k11 - k22;


  if (eta < 0) {
    // Usual case: jump to the unconstrained optimum along the line, then
    // clip it to [L, H].
    a2 = alph2 + y2 * (E2 - E1) / eta;
    if (a2 < L)
      a2 = L;
    else if (a2 > H)
      a2 = H;
  }
  else {
    // Degenerate case (eta >= 0): evaluate the objective, as a function
    // c1*a2^2 + c2*a2 of the new alpha_2, at both ends of the segment.
    {
      double c1 = eta/2;
      double c2 = y2 * (E1-E2)- eta * alph2;
      Lobj = c1 * L * L + c2 * L;
      Hobj = c1 * H * H + c2 * H;
    }

    // Move to the end with the larger objective only if it wins by more
    // than eps_; otherwise leave alpha_2 where it is.
    if (Lobj > Hobj+eps_)
      a2 = L;
    else if (Lobj < Hobj-eps_)
      a2 = H;
    else
      a2 = alph2;
  }

  // Change is negligible relative to the magnitudes involved: no progress.
  if (vnl_math_abs(a2-alph2) < eps_*(a2+alph2+eps_) )
    return 0;

  // Move alpha_1 so that y1*alpha_1 + y2*alpha_2 is unchanged. If rounding
  // pushes alpha_1 outside [0, C_], clip it and hand the excess back to alpha_2.
  a1 = alph1 - s * (a2 - alph2);
  if (a1 < 0.0) {
    a2 += s * a1;
    a1 = 0;
  }
  else if (a1 > C_) {
    double t = a1-C_;
    a2 += s * t;
    a1 = C_;
  }


  // Update the threshold b_. If alpha_1 (or failing that alpha_2) lies
  // strictly inside the box (with an eps_^2 margin), the threshold that makes
  // its new error zero is used; otherwise the two candidates are averaged.
  // NOTE(review): this assumes learned_func() computes (weighted sum - b_),
  // as in Platt's u = w.x - b -- confirm against its definition.
  double delta_b;
  {
    double b1, b2, bnew;

    const double eps_2 = eps_*eps_;

    if (a1 > eps_2 && a1 < (C_*(1-eps_2)))
      bnew = b_ + E1 + y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12;
    else {
      if (a2 > eps_2 && a2 < (C_*(1-eps_2)))
        bnew = b_ + E2 + y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22;
      else {
        b1 = b_ + E1 + y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12;
        b2 = b_ + E2 + y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22;
        bnew = (b1 + b2) / 2;
      }
    }
    delta_b = bnew - b_;
    b_ = bnew;
  }

  // Incrementally update the cached errors of the non-bound examples. At
  // this point alph_ still holds the OLD values for i1 and i2, so the bound
  // test below uses the pre-step state. The pair's own errors are then reset
  // to zero.
  {
    const double t1 = y1 * (a1-alph1);
    const double t2 = y2 * (a2-alph2);

    for (unsigned int i=0; i<data_->size(); i++)
      if (0 < alph_[i] && alph_[i] < C_)
        error_cache_[i] +=  t1 * kernel(i1,i) + t2 * kernel(i2,i)
        - delta_b;
    error_cache_[i1] = 0.0;
    error_cache_[i2] = 0.0;
  }

  alph_[i1] = a1;  // Store a1 in the alpha array.
  alph_[i2] = a2;  // Store a2 in the alpha array.

  return 1;
}
00240 
00241 // ----------------------------------------------------------------
00242 
00243 int clsfy_smo_1_lin::examine_example(int i1)
00244 {
00245   double  E1, r1;
00246   const unsigned long N = data_->size();
00247 
00248   const double y1 = target_[i1];
00249   const double alph1 = alph_(i1);
00250 
00251   if (alph1 > 0 && alph1 < C_)
00252     E1 = error_cache_[i1];
00253   else
00254     E1 = learned_func(i1) - y1;
00255 
00256   r1 = y1 * E1;
00257   if ((r1 < -tolerance_ && alph1 < C_) ||
00258       (r1 > tolerance_ && alph1 > 0)) // is the KKT condition for alph1 broken?
00259   {
00260     // Try i2 by three ways; if successful, then immediately return 1;
00261     {
00262       unsigned int k;
00263       int i2;
00264       double tmax;
00265 
00266       // Second choice heuristic A - Find the example i2 which maximises
00267       // |E1 - E2| where E1
00268 
00269       for (i2 = (-1), tmax = 0, k = 0; k < N; ++k)
00270         if (alph_(k) > 0 && alph_(k) < C_)
00271         {
00272           double E2, temp;
00273 
00274           E2 = error_cache_[k];
00275           temp = vnl_math_abs(E1 - E2);
00276           if (temp > tmax)
00277           {
00278             tmax = temp;
00279             i2 = k;
00280           }
00281         }
00282 
00283       if (i2 >= 0) {
00284         if (take_step (i1, i2, E1))
00285           return 1;
00286       }
00287     }
00288 
00289 
00290     // second choice Heuristic B - Find any unbound example that give positive progress.
00291     // start from random location
00292     for (unsigned long k0 = rng_.lrand32(N-1), k = k0; k < N + k0; ++k)
00293     {
00294       unsigned long i2 = k % N;
00295       if (alph_(i2) > 0 && alph_(i2) < C_)
00296       {
00297         if (take_step(i1, i2, E1))
00298           return 1;
00299       }
00300     }
00301 
00302 
00303     // second choice Heuristic C - Find any example that give positive progress.
00304     // start from random location
00305     for (unsigned long k0 = rng_.lrand32(N-1), k = k0; k < N + k0; ++k)
00306     {
00307       unsigned long i2 = k % N;
00308       if (alph_(i2) == 0 || alph_(i2) == C_)
00309       {
00310         if (take_step(i1, i2, E1))
00311           return 1;
00312       }
00313     }
00314   }
00315   return 0;
00316 }
00317 
00318 // ----------------------------------------------------------------
00319 
00320 int clsfy_smo_1_rbf::calc()
00321 {
00322   assert(gamma_!=0.0);  //IS gamma set?
00323   return clsfy_smo_1_lin::calc();
00324 }
00325 
00326 // ----------------------------------------------------------------
00327 
00328 int clsfy_smo_1_lin::calc()
00329 {
00330   //Check a bunch of things
00331 
00332   assert (data_ != 0); // Check that the data has been set.
00333 
00334   const unsigned long N = data_->size();
00335   assert(N != 0);     // Check that there is some data.
00336 
00337   if (alph_.empty()) // only initialise alph if it hasn't been externally set.
00338   {
00339     alph_.set_size(N);
00340     alph_.fill(0.0);
00341   }
00342 
00343 
00344   // E_i = u_i - y_i = 0 - y_i = -y_i
00345   error_cache_.resize(N);
00346 
00347   unsigned long numChanged = 0;
00348   bool examineAll = true;
00349   while (numChanged > 0 || examineAll)
00350   {
00351     numChanged = 0;
00352     if (examineAll)
00353       for (unsigned int k = 0; k < N; k++)
00354         numChanged += examine_example (k);
00355     else
00356       for (unsigned int k = 0; k < N; k++)
00357         if (alph_[k] != 0 && alph_[k] != C_)
00358           numChanged += examine_example (k);
00359     if (examineAll)
00360       examineAll = false;
00361     else if (numChanged == 0)
00362       examineAll = true;
00363 
00364 #if !defined NDEBUG &&  CLSFY_SMO_BASE_PRINT_PROGRESS >1
00365     {
00366       double s = 0.;
00367       for (int i=0; i<N; i++)
00368         s += alph_[i];
00369       double t = 0.;
00370       for (int i=0; i<N; i++)
00371       {
00372         if (alph_(i) != 0.0)
00373         {
00374           for (int j=0; j<N; j++)
00375             if (alph_[j] != 0.0)
00376             t += alph_[i]*alph_[j]*target_[i]*target_[j]*kernel(i,j);
00377         }
00378       }
00379       vcl_cerr << "Objective function=" << (s - t/2.) << '\t';
00380       for (int i=0; i<N; i++)
00381         if (alph_[i] < 0)
00382           vcl_cerr << "alph_[" << i << "]=" << alph_[i] << " < 0\n";
00383       s = 0.;
00384       for (int i=0; i<N; i++)
00385         s += alph_[i] * target_[i];
00386       vcl_cerr << "s=" << s << "\terror_rate=" << error_rate() << '\t';
00387     }
00388 #endif
00389 
00390 #if !defined NDEBUG && CLSFY_SMO_BASE_PRINT_PROGRESS
00391     {
00392       int non_bound_support =0;
00393       int bound_support =0;
00394       for (int i=0; i<N; i++)
00395         if (alph_[i] > 0)
00396         {
00397           if (alph_[i] < C_)
00398             non_bound_support++;
00399           else
00400             bound_support++;
00401         }
00402       vcl_cerr << "non_bound=" << non_bound_support << '\t'
00403                << "bound_support=" << bound_support << vcl_endl;
00404     }
00405 #endif
00406   }
00407 
00408   error_ = error_rate();
00409 
00410 #if !defined NDEBUG && CLSFY_SMO_BASE_PRINT_PROGRESS
00411   vcl_cerr << "Threshold=" << b_ << vcl_endl;
00412   vcl_cout << "Error rate=" << error_ << vcl_endl;
00413 #endif
00414 
00415   return 0;
00416 }