Go to the documentation of this file.00001
00002 #ifndef mbl_lda_h_
00003 #define mbl_lda_h_
00004 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00005 #pragma interface
00006 #endif
00007
00008
00009
00010
00011
00012
00013 #include <vcl_string.h>
00014 #include <vcl_vector.h>
00015 #include <vcl_iosfwd.h>
00016
00017 #include <vnl/vnl_vector.h>
00018 #include <vnl/vnl_matrix.h>
00019 #include <vnl/io/vnl_io_matrix.h>
00020
00021
00022
00023 class mbl_lda
00024 {
00025 private:
00026
00027 vcl_vector<vnl_vector<double> > mean_;
00028 vcl_vector<vnl_vector<double> > d_mean_;
00029 vnl_vector<double> mean_class_mean_;
00030 vcl_vector<int> n_samples_;
00031 vnl_matrix<double> withinS_;
00032 vnl_matrix<double> betweenS_;
00033 vnl_matrix<double> basis_;
00034 vnl_vector<double> evals_;
00035 vnl_vector<double> d_m_mean_;
00036
00037 void updateCovar(vnl_matrix<double>& S, const vnl_vector<double>& v);
00038
00039
00040
00041
00042
00043
00044
00045
00046 void build(const vnl_vector<double>* v, const int* label, int n,
00047 const vnl_matrix<double>& wS, bool compute_wS);
00048
00049 public:
00050
00051
00052 mbl_lda();
00053
00054
00055 virtual ~mbl_lda();
00056
00057
00058 bool operator==
00059 (const mbl_lda& that) const;
00060
00061
00062
00063 int classify(const vnl_vector<double>& x) const;
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074 void build(const vnl_vector<double>* v, const int* label, int n);
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084 void build(const vnl_vector<double>* v, const vcl_vector<int>& label);
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095 void build(const vnl_vector<double>* v, const vcl_vector<int>& label,
00096 const vnl_matrix<double>& wS);
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106 void build(const vcl_vector<vnl_vector<double> >& v, const vcl_vector<int>& label);
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116 void build(const vcl_vector<vnl_vector<double> >& v, const vcl_vector<int>& label,
00117 const vnl_matrix<double>& wS);
00118
00119
00120
00121
00122
00123
00124
00125
00126
00127 void build(const vnl_matrix<double>& M, const vcl_vector<int>& label);
00128
00129
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140 void build(const vnl_matrix<double>& M, const vcl_vector<int>& label,
00141 const vnl_matrix<double>& wS);
00142
00143
00144 int n_classes() const { return mean_.size(); }
00145
00146
00147 int n_samples(int i) const { return n_samples_[i]; }
00148
00149
00150 const vnl_vector<double>& class_mean(int i) const { return mean_[i]; }
00151
00152
00153 const vnl_vector<double>& d_class_mean(int i) const { return d_mean_[i]; }
00154
00155
00156 const vnl_vector<double>& mean_class_mean() const { return mean_class_mean_; }
00157
00158
00159 const vnl_matrix<double>& within_covar() const { return withinS_; }
00160
00161
00162 const vnl_matrix<double>& between_covar() const { return betweenS_; }
00163
00164
00165 const vnl_matrix<double>& basis() const { return basis_; }
00166
00167
00168 const vnl_vector<double>& basis_e_vals() const { return evals_; }
00169
00170
00171 void x_to_d(vnl_vector<double>& d, const vnl_vector<double>& x) const;
00172
00173
00174 void d_to_x(vnl_vector<double>& x, const vnl_vector<double>& d) const;
00175
00176
00177 int nDistinctIDs(const int* id, const int n);
00178
00179
00180 short version_no() const;
00181
00182
00183 virtual vcl_string is_a() const;
00184
00185
00186 virtual bool is_class(vcl_string const& s) const;
00187
00188
00189 virtual void print_summary(vcl_ostream& os) const;
00190
00191
00192 virtual void b_write(vsl_b_ostream& bfs) const;
00193
00194
00195 virtual void b_read(vsl_b_istream& bfs);
00196 };
00197
00198
00199
00200 void vsl_b_write(vsl_b_ostream& bfs, const mbl_lda& b);
00201
00202
00203 void vsl_b_read(vsl_b_istream& bfs, mbl_lda& b);
00204
00205
00206 void vsl_print_summary(vcl_ostream& os, const mbl_lda& b);
00207
00208
00209 vcl_ostream& operator<<(vcl_ostream& os,const mbl_lda& b);
00210
00211 #endif // mbl_lda_h_