00001
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include "clsfy_adaboost_trainer.h"
00021
00022 #include <vcl_iostream.h>
00023 #include <vsl/vsl_indent.h>
00024 #include <vcl_cmath.h>
00025 #include <vcl_cassert.h>
00026
00027
00028
00029 clsfy_adaboost_trainer::clsfy_adaboost_trainer()
00030 {
00031 }
00032
00033
00034
00035 clsfy_adaboost_trainer::~clsfy_adaboost_trainer()
00036 {
00037 }
00038
00039
00040
00041 void clsfy_adaboost_trainer::clsfy_get_elements(
00042 vnl_vector<double>& v,
00043 mbl_data_wrapper<vnl_vector<double> >& data,
00044 int j)
00045 {
00046 unsigned long n = data.size();
00047 v.set_size(n);
00048 data.reset();
00049 for (unsigned long i=0;i<n;++i)
00050 {
00051 v[i] = data.current()[j];
00052 data.next();
00053 }
00054 }
00055
00056
00057
00058 void clsfy_adaboost_trainer::clsfy_update_weights_weak(
00059 vnl_vector<double> &wts,
00060 const vnl_vector<double>& data,
00061 clsfy_classifier_1d& classifier,
00062 int class_number,
00063 double beta)
00064 {
00065 assert(class_number >= 0);
00066 unsigned int n = wts.size();
00067 for (unsigned int i=0;i<n;++i)
00068 if (classifier.classify(data[i])==(unsigned)class_number) wts[i]*=beta;
00069 }
00070
00071
00072
00073
00074
00075
00076
00077 void clsfy_adaboost_trainer::build_strong_classifier(
00078 clsfy_simple_adaboost& strong_classifier,
00079 int max_n_clfrs,
00080 clsfy_builder_1d& builder,
00081 mbl_data_wrapper<vnl_vector<double> >& egs0,
00082 mbl_data_wrapper<vnl_vector<double> >& egs1)
00083 {
00084
00085 strong_classifier.clear();
00086
00087
00088 clsfy_classifier_1d* c1d = builder.new_classifier();
00089 clsfy_classifier_1d* best_c1d= builder.new_classifier();
00090
00091 unsigned long n0 = egs0.size();
00092 unsigned long n1 = egs1.size();
00093 int n=max_n_clfrs;
00094
00095
00096 unsigned int d = egs0.current().size();
00097 strong_classifier.set_n_dims(d);
00098
00099
00100 vnl_vector<double> wts0(n0,0.5/n0);
00101 vnl_vector<double> wts1(n1,0.5/n1);
00102
00103 vnl_vector<double> egs0_1d, egs1_1d;
00104
00105 for (int i=0;i<n;++i)
00106 {
00107 vcl_cout<<"adaboost training round = "<<i<<'\n';
00108
00109
00110
00111 int best_j=-1;
00112 double min_error= 100000;
00113 for (unsigned int j=0;j<d;++j)
00114 {
00115
00116 clsfy_get_elements(egs0_1d,egs0,j);
00117 clsfy_get_elements(egs1_1d,egs1,j);
00118
00119 double error = builder.build(*c1d,egs0_1d,wts0,egs1_1d,wts1);
00120
00121 if (j==0 || error<min_error)
00122 {
00123 min_error = error;
00124 delete best_c1d;
00125 best_c1d= c1d->clone();
00126 best_j = j;
00127 }
00128 }
00129
00130 vcl_cout<<"best_j= "<<best_j<<'\n'
00131 <<"min_error= "<<min_error<<'\n';
00132
00133 if (min_error<1e-10)
00134 {
00135 vcl_cout<<"min_error<1e-10 !!!\n";
00136 double alpha = vcl_log(2.0*(n0+n1));
00137 strong_classifier.add_classifier( best_c1d, alpha, best_j);
00138
00139
00140 delete c1d;
00141 delete best_c1d;
00142 return;
00143 }
00144
00145
00146 if (0.5-min_error<1e-10)
00147 {
00148 vcl_cout<<"min_error => 0.5 !!!\n";
00149
00150
00151 delete c1d;
00152 delete best_c1d;
00153 return;
00154 }
00155
00156 double beta = min_error/(1.0-min_error);
00157 double alpha = -1.0*vcl_log(beta);
00158 strong_classifier.add_classifier( best_c1d, alpha, best_j);
00159
00160 if (i<(n-1))
00161 {
00162
00163 clsfy_get_elements(egs0_1d,egs0,best_j);
00164 clsfy_get_elements(egs1_1d,egs1,best_j);
00165
00166 clsfy_update_weights_weak(wts0,egs0_1d,*best_c1d,0,beta);
00167 clsfy_update_weights_weak(wts1,egs1_1d,*best_c1d,1,beta);
00168
00169
00170 double w_sum = wts0.mean()*n0 + wts1.mean()*n1;
00171 wts0/=w_sum;
00172 wts1/=w_sum;
00173 }
00174 }
00175
00176 delete c1d;
00177 delete best_c1d;
00178 }
00179
00180
00181
00182 short clsfy_adaboost_trainer::version_no() const
00183 {
00184 return 1;
00185 }
00186
00187
00188
00189 vcl_string clsfy_adaboost_trainer::is_a() const
00190 {
00191 return vcl_string("clsfy_adaboost_trainer");
00192 }
00193
00194 bool clsfy_adaboost_trainer::is_class(vcl_string const& s) const
00195 {
00196 return s == clsfy_adaboost_trainer::is_a();
00197 }
00198
00199
00200
// NOTE(review): the copy constructor and assignment operator below are
// compiled out. They refer to members data_ptr_ and data_, which no compiled
// code in this file uses (the default constructor initialises nothing), so
// they look like leftovers from a template. Confirm against the class header
// before re-enabling.
#if 0

// Copy constructor: starts with no owned data, then delegates to operator=.
clsfy_adaboost_trainer::clsfy_adaboost_trainer(const clsfy_adaboost_trainer& new_b):
data_ptr_(0)
{
*this = new_b;
}

// Assignment: guards against self-assignment, deep-copies data_ptr_ via
// clone() (after deleting the current one), then copies data_.
clsfy_adaboost_trainer& clsfy_adaboost_trainer::operator=(const clsfy_adaboost_trainer& new_b)
{
if (&new_b==this) return *this;


delete data_ptr_; data_ptr_=0;

if (new_b.data_ptr_)
data_ptr_ = new_b.data_ptr_->clone();


data_ = new_b.data_;

return *this;
}

#endif // 0
00230
00231
00232
00233
00234 void clsfy_adaboost_trainer::print_summary(vcl_ostream& ) const
00235 {
00236
00237 vcl_cerr << "clsfy_adaboost_trainer::print_summary() NYI\n";
00238 }
00239
00240
00241
00242
00243 void clsfy_adaboost_trainer::b_write(vsl_b_ostream& ) const
00244 {
00245
00246
00247 vcl_cerr << "clsfy_adaboost_trainer::b_write() NYI\n";
00248 }
00249
00250
00251
00252
00253 void clsfy_adaboost_trainer::b_read(vsl_b_istream& )
00254 {
00255 vcl_cerr << "clsfy_adaboost_trainer::b_read() NYI\n";
00256 #if 0
00257 if (!bfs) return;
00258
00259 short version;
00260 vsl_b_read(bfs,version);
00261 switch (version)
00262 {
00263 case (1):
00264 vsl_b_read(bfs,data_);
00265 break;
00266 default:
00267 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_adaboost_trainer&)\n"
00268 << " Unknown version number "<< version << '\n';
00269 bfs.is().clear(vcl_ios::badbit);
00270 return;
00271 }
00272 #endif
00273 }
00274
00275
00276
00277 void vsl_b_write(vsl_b_ostream& bfs, const clsfy_adaboost_trainer& b)
00278 {
00279 b.b_write(bfs);
00280 }
00281
00282
00283
00284 void vsl_b_read(vsl_b_istream& bfs, clsfy_adaboost_trainer& b)
00285 {
00286 b.b_read(bfs);
00287 }
00288
00289
00290
00291 void vsl_print_summary(vcl_ostream& os,const clsfy_adaboost_trainer& b)
00292 {
00293 os << b.is_a() << ": ";
00294 vsl_indent_inc(os);
00295 b.print_summary(os);
00296 vsl_indent_dec(os);
00297 }
00298
00299
00300
00301 vcl_ostream& operator<<(vcl_ostream& os,const clsfy_adaboost_trainer& b)
00302 {
00303 vsl_print_summary(os,b);
00304 return os;
00305 }