#ifdef VCL_NEEDS_PRAGMA_INTERFACE
#pragma implementation
#endif
//:
// \file
// \brief Implementation of mbl_lda - linear discriminant analysis

#include "mbl_lda.h"

#include <vcl_algorithm.h>
#include <vcl_cassert.h>
#include <vcl_cstddef.h>
#include <vcl_cstring.h>
#include <vsl/vsl_indent.h>
#include <vsl/vsl_vector_io.h>
#include <vsl/vsl_binary_io.h>
#include <vnl/algo/vnl_svd.h>
#include <vnl/algo/vnl_symmetric_eigensystem.h>
#include <vnl/algo/vnl_generalized_eigensystem.h>
#include <vnl/io/vnl_io_vector.h>
#include <mbl/mbl_matxvec.h>
#include <mbl/mbl_log.h>
#include <mbl/mbl_exception.h>

//: Logger used for diagnostic output from this class
static mbl_logger& logger()
{
  static mbl_logger l("mul.mbl.lda");
  return l;
}

//: Constructor
mbl_lda::mbl_lda()
{
}

//: Destructor
mbl_lda::~mbl_lda()
{
}

//: Classify a new data point: project x into the discriminant space
//  and return the index of the nearest class mean.
int mbl_lda::classify(const vnl_vector<double>& x) const
{
  vnl_vector<double> d;
  x_to_d(d, x);
  int nc = n_classes();
  double min_d = (d - d_class_mean(0)).squared_magnitude();
  int min_i = 0;
  for (int i=1; i<nc; ++i)
  {
    double dist = (d - d_class_mean(i)).squared_magnitude();
    if (dist < min_d) { min_d = dist; min_i = i; }
  }
  return min_i;
}

//: Equality: true iff all stored model data match
bool mbl_lda::operator==(const mbl_lda& that) const
{
  return mean_            == that.mean_ &&
         d_mean_          == that.d_mean_ &&
         mean_class_mean_ == that.mean_class_mean_ &&
         n_samples_       == that.n_samples_ &&
         withinS_         == that.withinS_ &&
         betweenS_        == that.betweenS_ &&
         basis_           == that.basis_ &&
         evals_           == that.evals_ &&
         d_m_mean_        == that.d_m_mean_;
}

//: Add the outer product V*V' to the scatter matrix S,
//  resizing and zeroing S first if its size does not match V.
void mbl_lda::updateCovar(vnl_matrix<double>& S, const vnl_vector<double>& V)
{
  unsigned int n = V.size();
  if (S.rows()!=n)
  {
    S.set_size(n,n);
    S.fill(0);
  }

  double** s = S.data_array();
  const double* v = V.data_block();
  for (unsigned int i=0;i<n;++i)
  {
    double* row = s[i];
    double vi = v[i];
    for (unsigned int j=0;j<n;++j)
      row[j] += vi*v[j];
  }
}
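
// In matrix terms each call performs the rank-one update S <- S + V*V'.
// build() later divides the accumulated sums by the number of contributing
// terms to obtain the between-class and within-class scatter estimates.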

//: Return the number of distinct values in id[0..n-1]
int mbl_lda::nDistinctIDs(const int* id, const int n)
{
  vcl_vector<int> dids;
  for (int i=0;i<n;++i)
  {
    if (vcl_find(dids.begin(), dids.end(), id[i])==dids.end())
      dids.push_back(id[i]);
  }

  return dids.size();
}

//: Perform LDA on the n data vectors in v.
//  label[i] gives the class of v[i]; samples with a negative label are ignored.
//  If compute_wS is true the within-class scatter is estimated from the data,
//  otherwise the supplied matrix wS is used.
void mbl_lda::build(const vnl_vector<double>* v, const int* label, int n,
                    const vnl_matrix<double>& wS, bool compute_wS)
{
  // Find the range of class labels and count the valid samples
  int lo_i=label[0];
  int hi_i=-1;
  int n_valid = 0;
  for (int i=0;i<n;++i)
  {
    if (label[i]>=0)
    {
      if (label[i]<lo_i) lo_i=label[i];
      if (label[i]>hi_i) hi_i=label[i];
      n_valid++;
    }
  }

  int n_classes = nDistinctIDs(label,n);
  MBL_LOG(INFO, logger(), "There are " << n_classes << " classes to build LDA space");
  MBL_LOG(INFO, logger(), "Max label index is " << hi_i);
  MBL_LOG(INFO, logger(), "Min label index is " << lo_i);

  // Accumulate the sum (later the mean) of each class
  int n_size=hi_i+1;
  mean_.resize(n_size);
  n_samples_.resize(n_size);
  for (int i=0;i<n_size;++i)
    n_samples_[i]=0;

  for (int i=0;i<n;++i)
  {
    int l = label[i];
    if (l<0) continue;   // ignore samples with negative labels
    if (mean_[l].size()==0)
    {
      mean_[l] = v[i];
      n_samples_[l] = 1;
    }
    else
    {
      mean_[l] += v[i];
      n_samples_[l] += 1;
    }
  }

  // Turn sums into class means and compute the mean of the class means
  int n_used_classes = 0;
  for (int i=0;i<n_size;++i)
  {
    if (n_samples_[i]>0)
    {
      mean_[i] /= n_samples_[i];
      if (i==lo_i) mean_class_mean_  = mean_[i];
      else         mean_class_mean_ += mean_[i];
      n_used_classes++;
    }
  }
  MBL_LOG(INFO, logger(), "Number of used classes: " << n_used_classes);

  mean_class_mean_ /= n_used_classes;

  // Between-class scatter: scatter of the class means about their mean
  betweenS_.set_size(0,0);

  for (int i=0;i<n_size;++i)
  {
    if (n_samples_[i]>0)
      updateCovar(betweenS_, mean_[i] - mean_class_mean_);
  }

  betweenS_ /= n_used_classes;

  if (compute_wS)
  {
    // Within-class scatter: scatter of each sample about its class mean,
    // using only classes that contain more than one sample
    withinS_.set_size(0,0);

    int n_used=0;
    for (int i=0;i<n;++i)
    {
      int l=label[i];
      if (l>=0 && n_samples_[l]>1)
      {
        updateCovar(withinS_, v[i]-mean_[l]);
        n_used++;
      }
    }
    withinS_ /= n_used;
  }
  else
    withinS_ = wS;

#if 0
  // Disabled alternative: invert the within-class scatter explicitly and
  // take the eigenvectors of inv(withinS_) * betweenS_.
  vnl_matrix<double> wS_inv;

  vnl_svd<double> wS_svd(withinS_, -1.0e-10);

  wS_inv = wS_svd.inverse();

  vnl_matrix<double> B = withinS_*wS_inv;
  vcl_cout<<B<<vcl_endl;

  vnl_matrix<double> A = wS_inv * betweenS_;

  vnl_matrix<double> EVecs(A.rows(), A.columns());
  vnl_vector<double> evals(A.columns());

  vnl_symmetric_eigensystem_compute(A, EVecs, evals);
#endif // 0

  // Solve the generalized eigenproblem  betweenS_ * v = lambda * withinS_ * v
  vnl_generalized_eigensystem gen_eigs(betweenS_, withinS_);
  vnl_matrix<double> EVecs = gen_eigs.V;
  vnl_vector<double> evals = gen_eigs.D.diagonal();
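
// These generalized eigenvectors maximize the Fisher criterion
//   J(v) = (v' * betweenS_ * v) / (v' * withinS_ * v),
// so directions with large eigenvalues separate the class means well
// relative to the within-class spread.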

  if (logger().level()>=mbl_logger::DEBUG)
  {
    MBL_LOG(DEBUG, logger(), "eigen decomp in original order:");
    unsigned nvec = EVecs.cols();
    for (unsigned i=0; i<nvec; ++i)
      MBL_LOG(DEBUG, logger(), "Col " << i << ": " << EVecs.get_column(i)
              << " (magn: " << EVecs.get_column(i).magnitude() << ')');
    for (unsigned i=0; i<nvec; ++i)
      MBL_LOG(DEBUG, logger(), "eval " << i << ": " << evals[i]);
  }

  // The scatter matrices are positive semi-definite, so all eigenvalues
  // should be non-negative (up to numerical noise).
  for (unsigned i=0; i<evals.size(); ++i)
  {
    if (evals[i]<-1e-12)
      throw mbl_exception_abort("mbl_lda::build(): found negative eigenvalue(s)");
  }

  // Reverse the order so that the largest eigenvalues (and their
  // eigenvectors) come first.
  evals.flip();
  EVecs.fliplr();

  if (logger().level()>=mbl_logger::DEBUG)
  {
    MBL_LOG(DEBUG, logger(), "eigen decomp in sorted order:");
    unsigned nvec = EVecs.cols();
    for (unsigned i=0; i<nvec; ++i)
      MBL_LOG(DEBUG, logger(), "Col " << i << ": " << EVecs.get_column(i)
              << " (magn: " << EVecs.get_column(i).magnitude() << ')');
    for (unsigned i=0; i<nvec; ++i)
      MBL_LOG(DEBUG, logger(), "eval " << i << ": " << evals[i]);
  }

  // The discriminant space has at most (number of used classes - 1)
  // dimensions, and no more than the dimensionality of the data.
  int m = EVecs.rows();
  int t = n_used_classes-1;
  if (t>m) t=m;

  // Copy the first t columns of EVecs into the basis matrix
  basis_.set_size(m,t);
  double** E = EVecs.data_array();
  double** b = basis_.data_array();
  vcl_size_t bytes_per_row = t * sizeof(double);
  for (int i=0;i<m;++i)
  {
    vcl_memcpy(b[i], E[i], bytes_per_row);
  }

  MBL_LOG(DEBUG, logger(), "basis matrix before normalization:");
  basis_.print(logger().log(mbl_logger::DEBUG));

  basis_.normalize_columns();
  MBL_LOG(DEBUG, logger(), "basis matrix after normalization:");
  basis_.print(logger().log(mbl_logger::DEBUG));
  logger().log(mbl_logger::DEBUG) << vcl_flush;

  // Keep the eigenvalues associated with the retained basis vectors
  evals_.set_size(t);
  for (int i=0;i<t;++i)
    evals_[i] = evals[i];

  // Project the mean of the class means into the discriminant space
  d_m_mean_.set_size(t);
  mbl_matxvec_prod_vm(mean_class_mean_, basis_, d_m_mean_);

  // Pre-compute the projection of every class mean
  d_mean_.resize(n_size);
  for (int i=0;i<n_size;++i)
    if (n_samples_[i]>0)
      x_to_d(d_mean_[i], mean_[i]);
}
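
// A minimal usage sketch (illustrative only; 'data', 'labels' and
// 'new_sample' are assumed to be provided by the caller):
//
//   vcl_vector<vnl_vector<double> > data;   // training vectors
//   vcl_vector<int> labels;                 // class label per vector
//   mbl_lda lda;
//   lda.build(data, labels);                // estimate the discriminant basis
//   vnl_vector<double> new_sample;          // vector to classify
//   int c = lda.classify(new_sample);       // index of nearest class mean
//   vnl_vector<double> d;
//   lda.x_to_d(d, new_sample);              // its discriminant-space coordinates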

//: Perform LDA on the n data vectors in v, estimating the within-class scatter
void mbl_lda::build(const vnl_vector<double>* v, const int* label, int n)
{
  build(v,label,n,vnl_matrix<double>(),true);
}

//: Perform LDA on the data in v, one label per vector
void mbl_lda::build(const vnl_vector<double>* v, const vcl_vector<int>& label)
{
  build(v,&label.front(),label.size(),vnl_matrix<double>(),true);
}

//: Perform LDA on the data in v, using the supplied within-class scatter wS
void mbl_lda::build(const vnl_vector<double>* v, const vcl_vector<int>& label,
                    const vnl_matrix<double>& wS)
{
  build(v,&label.front(),label.size(),wS,false);
}

//: Perform LDA on the vectors in v, one label per vector
void mbl_lda::build(const vcl_vector<vnl_vector<double> >& v, const vcl_vector<int>& label)
{
  assert(v.size()==label.size());
  build(&v.front(),&label.front(),label.size(),vnl_matrix<double>(),true);
}

//: Perform LDA on the vectors in v, using the supplied within-class scatter wS
void mbl_lda::build(const vcl_vector<vnl_vector<double> >& v, const vcl_vector<int>& label,
                    const vnl_matrix<double>& wS)
{
  assert(v.size()==label.size());
  build(&v.front(),&label.front(),label.size(),wS,false);
}

//: Perform LDA on the data in M; each column of M is a separate data sample
void mbl_lda::build(const vnl_matrix<double>& M, const vcl_vector<int>& label)
{
  unsigned int n_egs = M.columns();
  assert(n_egs==label.size());

  vcl_vector<vnl_vector<double> > v(n_egs);
  for (unsigned int i=0;i<n_egs;++i)
  {
    v[i] = M.get_column(i);
  }
  build(&v.front(),&label.front(),n_egs,vnl_matrix<double>(),true);
}

//: Perform LDA on the columns of M, using the supplied within-class scatter wS
void mbl_lda::build(const vnl_matrix<double>& M, const vcl_vector<int>& label,
                    const vnl_matrix<double>& wS)
{
  unsigned int n_egs = M.columns();
  assert(n_egs==label.size());

  vcl_vector<vnl_vector<double> > v(n_egs);
  for (unsigned int i=0;i<n_egs;++i)
  {
    v[i] = M.get_column(i);
  }
  build(&v.front(),&label.front(),n_egs,wS,false);
}

//: Project x into the discriminant space (relative to the mean of the class means)
void mbl_lda::x_to_d(vnl_vector<double>& d, const vnl_vector<double>& x) const
{
  d.set_size(d_m_mean_.size());
  mbl_matxvec_prod_vm(x, basis_, d);
  d -= d_m_mean_;
}

//: Map discriminant-space coordinates d back into the original data space
void mbl_lda::d_to_x(vnl_vector<double>& x, const vnl_vector<double>& d) const
{
  mbl_matxvec_prod_mv(basis_, d, x);
  x += mean_class_mean_;
}
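
// Note: the basis is generally not square (t is at most n_used_classes-1),
// so d_to_x(x_to_d(x)) gives only an approximate reconstruction of x,
// not an exact inverse in general.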

//: Version number for I/O
short mbl_lda::version_no() const
{
  return 1;
}

//: Name of the class
vcl_string mbl_lda::is_a() const
{
  return vcl_string("mbl_lda");
}

//: True if s is the name of this class
bool mbl_lda::is_class(vcl_string const& s) const
{
  return s==is_a();
}

void mbl_lda::print_summary(vcl_ostream& os) const
{
  int n_classes = n_samples_.size();
  os << "n_classes= " << n_classes << '\n';
  for (int i=0; i<n_classes; ++i)
  {
    os << "n_samples_[" << i << "]= " << n_samples_[i] << '\n'
       << "mean_[" << i << "]= " << mean_[i] << '\n'
       << "d_mean_[" << i << "]= " << d_mean_[i] << '\n';
  }

  os << "withinS_= " << withinS_ << '\n'
     << "betweenS_= " << betweenS_ << '\n'
     << "basis_= " << basis_ << '\n'
     << "evals_= " << evals_ << '\n'
     << "d_m_mean_= " << d_m_mean_ << '\n';
}

//: Save class to a binary file stream
void mbl_lda::b_write(vsl_b_ostream& bfs) const
{
  vsl_b_write(bfs,version_no());
  vsl_b_write(bfs,mean_);
  vsl_b_write(bfs,d_mean_);
  vsl_b_write(bfs,mean_class_mean_);
  vsl_b_write(bfs,n_samples_);
  vsl_b_write(bfs,withinS_);
  vsl_b_write(bfs,betweenS_);
  vsl_b_write(bfs,basis_);
  vsl_b_write(bfs,evals_);
  vsl_b_write(bfs,d_m_mean_);
}

//: Load class from a binary file stream
void mbl_lda::b_read(vsl_b_istream& bfs)
{
  if (!bfs) return;

  short version;
  vsl_b_read(bfs,version);
  switch (version)
  {
   case 1:
    vsl_b_read(bfs,mean_);
    vsl_b_read(bfs,d_mean_);
    vsl_b_read(bfs,mean_class_mean_);
    vsl_b_read(bfs,n_samples_);
    vsl_b_read(bfs,withinS_);
    vsl_b_read(bfs,betweenS_);
    vsl_b_read(bfs,basis_);
    vsl_b_read(bfs,evals_);
    vsl_b_read(bfs,d_m_mean_);
    break;
   default:
    vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, mbl_lda&)\n"
             << "           Unknown version number " << version << vcl_endl;
    bfs.is().clear(vcl_ios::badbit); // mark the stream as corrupt
    return;
  }
}

void vsl_b_write(vsl_b_ostream& bfs, const mbl_lda& b)
{
  b.b_write(bfs);
}

void vsl_b_read(vsl_b_istream& bfs, mbl_lda& b)
{
  b.b_read(bfs);
}

vcl_ostream& operator<<(vcl_ostream& os, const mbl_lda& b)
{
  os << b.is_a() << ": ";
  vsl_indent_inc(os);
  b.print_summary(os);
  vsl_indent_dec(os);
  return os;
}

void vsl_print_summary(vcl_ostream& os, const mbl_lda& b)
{
  b.print_summary(os);
}