contrib/mul/mbl/mbl_lda.h
Go to the documentation of this file.
00001 // This is mul/mbl/mbl_lda.h
00002 #ifndef mbl_lda_h_
00003 #define mbl_lda_h_
00004 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00005 #pragma interface
00006 #endif
00007 //:
00008 // \file
00009 // \brief  Class to perform linear discriminant analysis
00010 // \author Tim Cootes
00011 //         Converted to VXL by Gavin Wheeler
00012 
00013 #include <vcl_string.h>
00014 #include <vcl_vector.h>
00015 #include <vcl_iosfwd.h>
00016 
00017 #include <vnl/vnl_vector.h>
00018 #include <vnl/vnl_matrix.h>
00019 #include <vnl/io/vnl_io_matrix.h>
00020 
00021 //=======================================================================
00022 //: Class to perform linear discriminant analysis (LDA): build() trains on labelled vectors, classify() labels new ones
00023 class mbl_lda
00024 {
00025  private:
00026   // Model data. The meanings noted below are taken from the doc comments of the matching public accessors.
00027   vcl_vector<vnl_vector<double> > mean_;  // mean_[i]: mean vector of class i in original space (see class_mean())
00028   vcl_vector<vnl_vector<double> > d_mean_;  // d_mean_[i]: mean vector of class i in discriminant space (see d_class_mean())
00029   vnl_vector<double> mean_class_mean_;  // Mean of the class means (see mean_class_mean())
00030   vcl_vector<int> n_samples_;  // n_samples_[i]: number of examples of class i (see n_samples())
00031   vnl_matrix<double> withinS_;  // Within class covariance matrix (see within_covar())
00032   vnl_matrix<double> betweenS_;  // Between class covariance matrix (see between_covar())
00033   vnl_matrix<double> basis_;  // Basis for discriminant space (see basis())
00034   vnl_vector<double> evals_;  // Eigenvalues associated with each basis vector (see basis_e_vals())
00035   vnl_vector<double> d_m_mean_;  // NOTE(review): no accessor in this header; presumably mean_class_mean_ in discriminant space - confirm in the .cxx
00036   // NOTE(review): defined outside this header; presumably adds the contribution of v to the covariance S - confirm
00037   void updateCovar(vnl_matrix<double>& S, const vnl_vector<double>& v);
00038 
00039   //: Perform LDA on data (private overload taking every argument explicitly)
00040   // Classes must be labeled from 0..n-1 (NOTE(review): presumably this n is the number of classes, not the parameter n - confirm)
00041   // \param label  Array [0..n-1] of integer indices. label[i] gives class of v[i]
00042   // \param n      Size of label and of v
00043   // \param v  Set of vectors [0..n-1]
00044   // \param wS  Within class covariance to use if compute_wS is false
00045   // \param compute_wS  If false, wS is used as the within class covariance. If true, wS is not used (presumably the covariance is then computed from v - confirm)
00046   void build(const vnl_vector<double>* v, const int* label, int n,
00047              const vnl_matrix<double>& wS, bool compute_wS);
00048 
00049  public:
00050 
00051   //: Default constructor
00052   mbl_lda();
00053 
00054   //: Destructor (virtual)
00055   virtual ~mbl_lda();
00056 
00057   //: Equality comparison with another mbl_lda
00058   bool operator==
00059     (const mbl_lda& that) const;
00060 
00061   //: Classify a new data point
00062   // Projects x into discriminant space and picks the closest class mean vector (the int returned is presumably that class's index - confirm)
00063   int classify(const vnl_vector<double>& x) const;
00064 
00065   //: Perform LDA on data
00066   // \param n  Number of examples
00067   // \param label  Array [0..n-1] of integer indices
00068   // \param v  Set of vectors [0..n-1]
00069   //
00070   // - label[i] gives class of v[i]
00071   // - If label[i]<0 the class is assumed to be unknown
00072   //   and example i is ignored
00073   // - Classes must be labeled from 0..n-1 (NOTE(review): presumably this n is the number of classes, not the number of examples - confirm)
00074   void build(const vnl_vector<double>* v, const int* label, int n);
00075 
00076   //: Perform LDA on data
00077   // \param label  Array [0..n-1] of integer indices (n is not a parameter here; presumably label.size() - confirm)
00078   // \param v  Set of vectors [0..n-1]
00079   //
00080   // - label[i] gives class of v[i]
00081   // - If label[i]<0 the class is assumed to be unknown
00082   //   and example i is ignored
00083   // - Classes must be labeled from 0..n-1
00084   void build(const vnl_vector<double>* v, const vcl_vector<int>& label);
00085 
00086   //: Perform LDA on data, using a supplied within class covariance
00087   // \param label  Array [0..n-1] of integer indices
00088   // \param v  Set of vectors [0..n-1]
00089   // \param wS  Within class covariance to use
00090   //
00091   // - label[i] gives class of v[i]
00092   // - If label[i]<0 the class is assumed to be unknown
00093   //   and example i is ignored
00094   // - Classes must be labeled from 0..n-1
00095   void build(const vnl_vector<double>* v, const vcl_vector<int>& label,
00096              const vnl_matrix<double>& wS);
00097 
00098   //: Perform LDA on data
00099   // \param label  Array [0..n-1] of integer indices
00100   // \param v  Set of vectors [0..n-1]
00101   //
00102   // - label[i] gives class of v[i]
00103   // - If label[i]<0 the class is assumed to be unknown
00104   //   and example i is ignored
00105   // - Classes must be labeled from 0..n-1
00106   void build(const vcl_vector<vnl_vector<double> >& v, const vcl_vector<int>& label);
00107 
00108   //: Perform LDA on data, using a supplied within class covariance
00109   // \param label  Array [0..n-1] of integer indices
00110   // \param v  Set of vectors [0..n-1]
00111   // \param wS  Within class covariance to use
00112   //
00113   // - label[i] gives class of v[i]
00114   // - Classes must be labeled from 0..n-1
00115   // - If label[i]<0 the class is assumed to be unknown and example i is ignored
00116   void build(const vcl_vector<vnl_vector<double> >& v, const vcl_vector<int>& label,
00117              const vnl_matrix<double>& wS);
00118 
00119   //: Perform LDA on data held as the columns of a matrix
00120   // - Columns of M form example vectors
00121   // - i'th column belongs to class label(i)
00122   // - Note: label([1..n]) not label([0..n-1]) (NOTE(review): looks inherited from a 1-based label type, label is a vcl_vector here - confirm against the .cxx)
00123   // - If label[i]<0 the class is assumed to be unknown
00124   //   and example i is ignored
00125   // - Note also that this is inefficient - it converts the
00126   //   matrix to an array and calls build(v,label)
00127   void build(const vnl_matrix<double>& M, const vcl_vector<int>& label);
00128 
00129   //: Perform LDA on data held as the columns of a matrix, using a supplied within class covariance
00130   // - Columns of M form example vectors
00131   // - i'th column belongs to class label(i)
00132   // - Note: label([1..n]) not label([0..n-1]) (NOTE(review): looks inherited from a 1-based label type, label is a vcl_vector here - confirm against the .cxx)
00133   // - Note also that this is inefficient - it converts the
00134   //   matrix to an array and calls build(v,label)
00135   // - If label[i]<0 the class is assumed to be unknown
00136   //   and example i is ignored
00137   // \param M     The columns of this matrix form the example vectors
00138   // \param label The vector of class labels corresponding to these examples
00139   // \param wS    Within class covariance to use
00140   void build(const vnl_matrix<double>& M, const vcl_vector<int>& label,
00141              const vnl_matrix<double>& wS);
00142 
00143   //: Number of classes (the number of class means held in mean_)
00144   int n_classes() const { return mean_.size(); }
00145 
00146   //: Number of examples of the i'th class (this accessor does no range check on i)
00147   int n_samples(int i) const { return n_samples_[i]; }
00148 
00149   //: Mean vector for i'th class in original space (this accessor does no range check on i)
00150   const vnl_vector<double>& class_mean(int i) const { return mean_[i]; }
00151 
00152   //: Mean vector for i'th class in discriminant space (this accessor does no range check on i)
00153   const vnl_vector<double>& d_class_mean(int i) const { return d_mean_[i]; }
00154 
00155   //: Mean of the class means
00156   const vnl_vector<double>& mean_class_mean() const { return mean_class_mean_; }
00157 
00158   //: Within class covariance matrix
00159   const vnl_matrix<double>& within_covar() const { return withinS_; }
00160 
00161   //: Between class covariance matrix
00162   const vnl_matrix<double>& between_covar() const { return betweenS_; }
00163 
00164   //: Basis for discriminant space
00165   const vnl_matrix<double>& basis() const { return basis_; }
00166 
00167   //: Eigenvalues associated with each basis vector
00168   const vnl_vector<double>& basis_e_vals() const { return evals_; }
00169 
00170   //: Project x into discriminant space, writing the result to d
00171   void x_to_d(vnl_vector<double>& d, const vnl_vector<double>& x) const;
00172 
00173   //: Project d from discriminant space into original space, writing the result to x
00174   void d_to_x(vnl_vector<double>& x, const vnl_vector<double>& d) const;
00175 
00176   //: Find out how many distinct ids there are in the label array id (n is presumably the length of id - confirm)
00177   int nDistinctIDs(const int* id, const int n);
00178 
00179   //: Version number for I/O
00180   short version_no() const;
00181 
00182   //: Name of the class
00183   virtual vcl_string is_a() const;
00184 
00185   //: True if this is (or is derived from) class named s
00186   virtual bool is_class(vcl_string const& s) const;
00187 
00188   //: Print a summary of the class to os
00189   virtual void print_summary(vcl_ostream& os) const;
00190 
00191   //: Save class to binary file stream
00192   virtual void b_write(vsl_b_ostream& bfs) const;
00193 
00194   //: Load class from binary file stream
00195   virtual void b_read(vsl_b_istream& bfs);
00196 };
00197 
00198 
00199 //: Binary file stream output for a class reference: writes b to bfs (presumably by calling b.b_write(bfs) - confirm in the .cxx)
00200 void vsl_b_write(vsl_b_ostream& bfs, const mbl_lda& b);
00201 
00202 //: Binary file stream input for a class reference: loads b from bfs (presumably by calling b.b_read(bfs) - confirm in the .cxx)
00203 void vsl_b_read(vsl_b_istream& bfs, mbl_lda& b);
00204 
00205 //: Print a summary of b to os (presumably by calling b.print_summary(os) - confirm in the .cxx)
00206 void vsl_print_summary(vcl_ostream& os, const mbl_lda& b);
00207 
00208 //: Stream output operator for class reference; returns a vcl_ostream& (presumably os) so that << can be chained
00209 vcl_ostream& operator<<(vcl_ostream& os,const mbl_lda& b);
00210 
00211 #endif // mbl_lda_h_