contrib/mul/mbl/mbl_lda.cxx
Go to the documentation of this file.
00001 // This is mul/mbl/mbl_lda.cxx
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005 //:
00006 // \file
00007 // \brief  Class to perform linear discriminant analysis
00008 // \author Tim Cootes
00009 //         Converted to VXL by Gavin Wheeler
00010 
00011 #include "mbl_lda.h"
00012 
00013 #include <vcl_algorithm.h>  // for vcl_find
00014 #include <vcl_cassert.h>
00015 #include <vcl_cstddef.h> // for size_t
00016 #include <vcl_cstring.h> // for memcpy()
00017 #include <vsl/vsl_indent.h>
00018 #include <vsl/vsl_vector_io.h>
00019 #include <vsl/vsl_binary_io.h>
00020 #include <vnl/algo/vnl_svd.h>
00021 #include <vnl/algo/vnl_symmetric_eigensystem.h>
00022 #include <vnl/algo/vnl_generalized_eigensystem.h>
00023 #include <vnl/io/vnl_io_vector.h>
00024 #include <mbl/mbl_matxvec.h>
00025 #include <mbl/mbl_log.h>
00026 #include <mbl/mbl_exception.h>
00027 
00028 
00029 //=========================================================================
00030 // Static function to create a static logger when first required
00031 //=========================================================================
00032 static mbl_logger& logger()
00033 {
00034   static mbl_logger l("mul.mbl.lda");
00035   return l;
00036 }
00037 
00038 
00039 //=======================================================================
00040 mbl_lda::mbl_lda()
00041 {
00042 }
00043 
00044 
00045 //=======================================================================
00046 mbl_lda::~mbl_lda()
00047 {
00048 }
00049 
00050 
00051 //=======================================================================
00052 //: Classify a new data point.
00053 // Projects into discriminant space and picks closest mean class vector
00054 int mbl_lda::classify(const vnl_vector<double>& x) const
00055 {
00056   vnl_vector<double> d;
00057   x_to_d(d, x);
00058   int nc=n_classes();
00059   double min_d=(d-d_class_mean(0)).squared_magnitude();
00060   int min_i=0;
00061   for (int i=1; i<nc; ++i)
00062   {
00063     double dist=(d-d_class_mean(i)).squared_magnitude();
00064     if (dist<min_d ) { min_d= dist; min_i=i; }
00065   }
00066   return min_i;
00067 }
00068 
00069 
00070 //=======================================================================
00071 //: Comparison
00072 bool mbl_lda::operator==(const mbl_lda& that) const
00073 {
00074   return mean_ == that.mean_ &&
00075          d_mean_ == that.d_mean_ &&
00076          mean_class_mean_ == that.mean_class_mean_ &&
00077          n_samples_ == that.n_samples_ &&
00078          withinS_ == that.withinS_ &&
00079          betweenS_ == that.betweenS_ &&
00080          basis_ == that.basis_ &&
00081          evals_ == that.evals_ &&
00082          d_m_mean_ == that.d_m_mean_;
00083 }
00084 
00085 
00086 //=======================================================================
00087 void mbl_lda::updateCovar(vnl_matrix<double>& S, const vnl_vector<double>& V)
00088 {
00089   unsigned int n = V.size();
00090   if (S.rows()!=n)
00091   {
00092     S.set_size(n,n);
00093     S.fill(0);
00094   }
00095 
00096   double** s = S.data_array();
00097   const double* v = V.data_block();
00098   for (unsigned int i=0;i<n;++i)
00099   {
00100     double *row = s[i];
00101     double vi = v[i];
00102     for (unsigned int j=0;j<n;++j)
00103       row[j] += vi*v[j];
00104   }
00105 }
00106 
00107 
00108 //=======================================================================
00109 // find out how many id in the label vector
00110 int mbl_lda::nDistinctIDs(const int* id, const int n)
00111 {
00112   vcl_vector<int> dids;
00113   for (int i=0;i<n;++i)
00114   {
00115     if (vcl_find(dids.begin(), dids.end(), id[i])==dids.end())  // if (Index(dids,id[i])<0)
00116       dids.push_back(id[i]);
00117   }
00118 
00119   return dids.size();
00120 }
00121 
00122 
00123 //=======================================================================
00124 //: Perform LDA on data
00125 // \param label  Array [0..n-1] of integers indices
00126 // \param v  Set of vectors [0..n-1]
00127 //
00128 // label[i] gives class of v[i]
00129 // Classes must be labeled from 0 to m-1
00130 void mbl_lda::build(const vnl_vector<double>* v, const int * label, int n,
00131                     const vnl_matrix<double>& wS, bool compute_wS)
00132 {
00133   // Find range of class indices and count #valid
00134   int lo_i=label[0]; // =n causes failure if lo_i is less than n
00135   int hi_i=-1;
00136   int n_valid = 0;
00137   for (int i=0;i<n;++i)
00138   {
00139     if (label[i]>=0)
00140     {
00141       if (label[i]<lo_i) lo_i=label[i];
00142       if (label[i]>hi_i) hi_i=label[i];
00143       n_valid++;
00144     }
00145   }
00146 
00147   //  assert(lo_i==0);
00148 
00149   // Compute mean of each class
00150   int n_classes = nDistinctIDs(label,n);
00151   MBL_LOG(INFO, logger(), "There are " <<n_classes << " classes to build LDA space");
00152   MBL_LOG(INFO, logger(), "Max label index is " << hi_i);
00153   MBL_LOG(INFO, logger(), "Min label index is " << lo_i);
00154 
00155   int n_size=hi_i+1;
00156   mean_.resize(n_size);
00157   n_samples_.resize(n_size);
00158   for (int i=0;i<n_size;++i)
00159     n_samples_[i]=0;
00160 
00161   for (int i=0;i<n;++i)
00162   {
00163     int l = label[i];
00164     if (l<0) continue;
00165     if (mean_[l].size()==0)
00166     {
00167       mean_[l] = v[i];
00168       n_samples_[l] = 1;
00169     }
00170     else
00171     {
00172       mean_[l] += v[i];
00173       n_samples_[l] += 1;
00174     }
00175   }
00176 
00177   int n_used_classes = 0;
00178   for (int i=0;i<n_size;++i)
00179   {
00180     if (n_samples_[i]>0)
00181     {
00182       mean_[i]/=n_samples_[i];
00183       if (i==lo_i) mean_class_mean_ = mean_[i];
00184       else      mean_class_mean_ += mean_[i];
00185       n_used_classes++;
00186     }
00187   }
00188   MBL_LOG(INFO, logger(), "Number of used classes: " << n_used_classes);
00189 
00190   mean_class_mean_/=n_used_classes;
00191 
00192   // Build between class covariance
00193   // Zero to start:
00194   betweenS_.set_size(0,0);
00195 
00196   for (int i=0;i<n_size;++i)
00197   {
00198     if (n_samples_[i]>0)
00199       updateCovar(betweenS_,mean_[i] - mean_class_mean_);
00200   }
00201 
00202   betweenS_/=n_used_classes;
00203 
00204   if (compute_wS)
00205   {
00206     withinS_.set_size(0,0);
00207     // Count number of samples used to build matrix
00208     int n_used=0;
00209     for (int i=0;i<n;++i)
00210     {
00211       int l=label[i];
00212       if (l>=0 && n_samples_[l]>1)
00213       {
00214         updateCovar(withinS_,v[i]-mean_[l]);
00215         n_used++;
00216       }
00217     }
00218     withinS_/=n_used;
00219   }
00220   else
00221     withinS_ = wS;
00222 
00223 #if 0
00224   vnl_matrix<double> wS_inv;
00225   //  NR_Inverse(wS_inv,withinS_);
00226   vnl_svd<double> wS_svd(withinS_, -1.0e-10); // important!!! as the sigma_min=0.0
00227 
00228   wS_inv = wS_svd.inverse();
00229 
00230   vnl_matrix<double> B=withinS_*wS_inv;
00231   vcl_cout<<B<<vcl_endl;
00232 
00233   vnl_matrix<double> A = wS_inv* betweenS_; // was: betweenS_ * wS_inv;
00234 
00235   // Compute eigenvectors and eigenvalues (descending order)
00236   vnl_matrix<double> EVecs(A.rows(), A.columns());
00237   vnl_vector<double> evals(A.columns());
00238   //  NR_CalcSymEigens(A,EVecs,evals,false);
00239 
00240   // **** A not necessarily symmetric!!!! ****
00241   vnl_symmetric_eigensystem_compute(A, EVecs, evals);
00242 #endif // 0
00243 
00244   vnl_generalized_eigensystem gen_eigs(betweenS_,withinS_);
00245   vnl_matrix<double> EVecs= gen_eigs.V;
00246   vnl_vector<double> evals= gen_eigs.D.diagonal();
00247 
00248   // Log some information that might be helpful for debugging
00249   if (logger().level()>=mbl_logger::DEBUG)
00250   {
00251     MBL_LOG(DEBUG, logger(), "eigen decomp in original order:");
00252     unsigned nvec = EVecs.cols();
00253     for (unsigned i=0; i<nvec; ++i)
00254       MBL_LOG(DEBUG, logger(), "Col " << i << ": " << EVecs.get_column(i)
00255               << "(magn: " << EVecs.get_column(i).magnitude() << ')');
00256     for (unsigned i=0; i<nvec; ++i)
00257       MBL_LOG(DEBUG, logger(), "eval " << i << ": " << evals[i]);
00258   }
00259 
00260   // Re-arrange the eigenvector matrix (columns) and eigenvalue vector into descending order.
00261   // Assume they are in order of increasing eigenvalue magnitude.
00262   // NB The output from vnl_generalized_eigensystem above will be in order of
00263   // increasing (signed) eigenvalue, not magnitude. If we ever get negative eigenvalues,
00264   // then the simple reversal of flip() and fliplr() will not be correct.
00265   // Not sure whether we could get (significant) negative eigenvalues, but let's check.
00266   for (unsigned i=0; i<evals.size(); ++i)
00267   {
00268     if (evals[i]<-1e-12) // tolerance?
00269       throw mbl_exception_abort("mbl_lda::build(): found negative eigenvalue(s)");
00270   }
00271   evals.flip();
00272   EVecs.fliplr();
00273 
00274   // Log some information that might be helpful for debugging
00275   if (logger().level()>=mbl_logger::DEBUG)
00276   {
00277     MBL_LOG(DEBUG, logger(), "eigen decomp in sorted order:");
00278     unsigned nvec = EVecs.cols();
00279     for (unsigned i=0; i<nvec; ++i)
00280       MBL_LOG(DEBUG, logger(), "Col " << i << ": " << EVecs.get_column(i)
00281               << "(magn: " << EVecs.get_column(i).magnitude() << ')');
00282     for (unsigned i=0; i<nvec; ++i)
00283       MBL_LOG(DEBUG, logger(), "eval " << i << ": " << evals[i]);
00284   }
00285 
00286   // Record n_classes-1 vector basis
00287   int m = EVecs.rows();
00288   int t = n_used_classes-1;
00289   if (t>m) t=m;
00290 
00291   // Copy first t eigenvectors to basis_
00292   basis_.set_size(m,t);
00293   double **E = EVecs.data_array();
00294   double **b = basis_.data_array();
00295   vcl_size_t bytes_per_row = t * sizeof(double);
00296   for (int i=0;i<m;++i)
00297   {
00298     vcl_memcpy(b[i],E[i],bytes_per_row);
00299   }
00300 
00301   // Normalize the basis vectors
00302   MBL_LOG(DEBUG, logger(), "basis matrix before normalization:");
00303   basis_.print(logger().log(mbl_logger::DEBUG));
00304   //MBL_LOG(NOTICE, logger(), "normalization turned OFF");
00305   basis_.normalize_columns();
00306   MBL_LOG(DEBUG, logger(), "basis matrix after normalization:");
00307   basis_.print(logger().log(mbl_logger::DEBUG));
00308   logger().log(mbl_logger::DEBUG) << vcl_flush;
00309 
00310   // Copy first t eigenvalues
00311   evals_.set_size(t);
00312   for (int i=0;i<t;++i)
00313     evals_[i] = evals[i];
00314 
00315   // Compute projection of mean into d space
00316   d_m_mean_.set_size(t);
00317   mbl_matxvec_prod_vm(mean_class_mean_,basis_,d_m_mean_);
00318 
00319   // Project each mean into d-space
00320   d_mean_.resize(n_size);
00321   for (int i=0;i<n_size;++i)
00322     if (n_samples_[i]>0)
00323       x_to_d(d_mean_[i],mean_[i]);
00324 }
00325 
00326 
00327 //=======================================================================
00328 //: Perform LDA on data
00329 void mbl_lda::build(const vnl_vector<double>* v, const int* label, int n)
00330 {
00331   build(v,label,n,vnl_matrix<double>(),true);
00332 }
00333 
00334 //=======================================================================
00335 //: Perform LDA on data
00336 void mbl_lda::build(const vnl_vector<double>* v, const vcl_vector<int>& label)
00337 {
00338   build(v,&label.front(),label.size(),vnl_matrix<double>(),true);
00339 }
00340 
00341 //=======================================================================
00342 //: Perform LDA on data
00343 void mbl_lda::build(const vnl_vector<double>* v, const vcl_vector<int>& label,
00344                     const vnl_matrix<double>& wS)
00345 {
00346   build(v,&label.front(),label.size(),wS,false);
00347 }
00348 
00349 //=======================================================================
00350 //: Perform LDA on data
00351 void mbl_lda::build(const vcl_vector<vnl_vector<double> >& v, const vcl_vector<int>& label)
00352 {
00353   assert(v.size()==label.size());
00354   build(&v.front(),&label.front(),label.size(),vnl_matrix<double>(),true);
00355 }
00356 
00357 //=======================================================================
00358 //: Perform LDA on data
00359 void mbl_lda::build(const vcl_vector<vnl_vector<double> >& v, const vcl_vector<int>& label,
00360                     const vnl_matrix<double>& wS)
00361 {
00362   assert(v.size()==label.size());
00363   build(&v.front(),&label.front(),label.size(),wS,false);
00364 }
00365 
00366 //=======================================================================
00367 //: Perform LDA on data
00368 //  Columns of M form example vectors
00369 //  i'th column belongs to class label[i]
00370 //  Note: label([1..n]) not label([0..n-1])
00371 void mbl_lda::build(const vnl_matrix<double>& M, const vcl_vector<int>& label)
00372 {
00373   unsigned int n_egs = M.columns();
00374   assert(n_egs==label.size());
00375   //  assert(label.lo()==1);
00376   vcl_vector<vnl_vector<double> > v(n_egs);
00377   for (unsigned int i=0;i<n_egs;++i)
00378   {
00379     v[i] = M.get_column(i);
00380   }
00381   build(&v.front(),&label.front(),n_egs,vnl_matrix<double>(),true);
00382 }
00383 
00384 //=======================================================================
00385 //: Perform LDA on data
00386 //  Columns of M form example vectors
00387 //  i'th column belongs to class label[i]
00388 //  Note: label([1..n]) not label([0..n-1])
00389 void mbl_lda::build(const vnl_matrix<double>& M, const vcl_vector<int>& label,
00390                     const vnl_matrix<double>& wS)
00391 {
00392   unsigned int n_egs = M.columns();
00393   assert(n_egs==label.size());
00394   //  assert(label.lo()==1);
00395   vcl_vector<vnl_vector<double> > v(n_egs);
00396   for (unsigned int i=0;i<n_egs;++i)
00397   {
00398     v[i] = M.get_column(i);
00399   }
00400   build(&v.front(),&label.front(),n_egs,wS,false);
00401 }
00402 
00403 
00404 //=======================================================================
00405 //: Project x into discriminant space
00406 void mbl_lda::x_to_d(vnl_vector<double>& d, const vnl_vector<double>& x) const
00407 {
00408   d.set_size(d_m_mean_.size());
00409   mbl_matxvec_prod_vm(x,basis_,d); // d = x' * M
00410   d-=d_m_mean_;
00411 }
00412 
00413 //=======================================================================
00414 //: Project d from discriminant space into original space
00415 void mbl_lda::d_to_x(vnl_vector<double>& x, const vnl_vector<double>& d) const
00416 {
00417   mbl_matxvec_prod_mv(basis_,d,x); // x = M * d
00418   x+=mean_class_mean_;
00419 }
00420 
00421 //=======================================================================
00422 
00423 short mbl_lda::version_no() const
00424 {
00425   return 1;
00426 }
00427 
00428 //=======================================================================
00429 
00430 vcl_string mbl_lda::is_a() const
00431 {
00432   return vcl_string("mbl_lda");
00433 }
00434 
00435 bool mbl_lda::is_class(vcl_string const& s) const
00436 {
00437   return s==is_a();
00438 }
00439 
00440 //=======================================================================
00441 
00442 void mbl_lda::print_summary(vcl_ostream& os) const
00443 {
00444   int n_classes= n_samples_.size();
00445   os << "n_classes= "<<n_classes<<'\n';
00446   for (int i=0; i<n_classes; ++i)
00447   {
00448     os <<"n_samples_["<<i<<"]= "<<n_samples_[i]<<'\n'
00449        <<"mean_["<<i<<"]= "<<mean_[i]<<'\n'
00450        <<"d_mean_["<<i<<"]= "<<d_mean_[i]<<'\n';
00451   }
00452 
00453   os << "withinS_= "<<withinS_<<'\n'
00454      << "betweenS_= "<<betweenS_<<'\n'
00455      << "basis_= "<<basis_<<'\n'
00456      << "evals_= "<<evals_<<'\n'
00457      << "d_m_mean_= "<<d_m_mean_<<'\n';
00458 }
00459 
00460 //=======================================================================
00461 
00462 void mbl_lda::b_write(vsl_b_ostream& bfs) const
00463 {
00464   vsl_b_write(bfs,version_no());
00465   vsl_b_write(bfs,mean_);
00466   vsl_b_write(bfs,d_mean_);
00467   vsl_b_write(bfs,mean_class_mean_);
00468   vsl_b_write(bfs,n_samples_);
00469   vsl_b_write(bfs,withinS_);
00470   vsl_b_write(bfs,betweenS_);
00471   vsl_b_write(bfs,basis_);
00472   vsl_b_write(bfs,evals_);
00473   vsl_b_write(bfs,d_m_mean_);
00474 }
00475 
00476 //=======================================================================
00477 
00478 void mbl_lda::b_read(vsl_b_istream& bfs)
00479 {
00480   if (!bfs) return;
00481 
00482   short version;
00483   vsl_b_read(bfs,version);
00484   switch (version)
00485   {
00486     case (1):
00487       vsl_b_read(bfs,mean_);
00488       vsl_b_read(bfs,d_mean_);
00489       vsl_b_read(bfs,mean_class_mean_);
00490       vsl_b_read(bfs,n_samples_);
00491       vsl_b_read(bfs,withinS_);
00492       vsl_b_read(bfs,betweenS_);
00493       vsl_b_read(bfs,basis_);
00494       vsl_b_read(bfs,evals_);
00495       vsl_b_read(bfs,d_m_mean_);
00496       break;
00497     default:
00498       // CHECK FUNCTION SIGNATURE IS CORRECT
00499       vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, mbl_lda &)\n"
00500                << "           Unknown version number "<< version << vcl_endl;
00501       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00502       return;
00503   }
00504 }
00505 
00506 //=======================================================================
00507 
00508 void vsl_b_write(vsl_b_ostream& bfs, const mbl_lda& b)
00509 {
00510   b.b_write(bfs);
00511 }
00512 
00513 //=======================================================================
00514 
00515 void vsl_b_read(vsl_b_istream& bfs, mbl_lda& b)
00516 {
00517   b.b_read(bfs);
00518 }
00519 
00520 //=======================================================================
00521 
00522 vcl_ostream& operator<<(vcl_ostream& os,const mbl_lda& b)
00523 {
00524   os << b.is_a() << ": ";
00525   vsl_indent_inc(os);
00526   b.print_summary(os);
00527   vsl_indent_dec(os);
00528   return os;
00529 }
00530 
00531 //=======================================================================
00532 void vsl_print_summary(vcl_ostream& os, const mbl_lda& b)
00533 {
00534   b.print_summary(os);
00535 }
00536 
00537