Go to the documentation of this file.00001
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #include "clsfy_direct_boost_builder.h"
00019 #include "clsfy_direct_boost.h"
00020 #include "clsfy_builder_1d.h"
00021
00022 #include <vcl_iostream.h>
00023 #include <vcl_cstdlib.h>
00024 #include <vcl_algorithm.h>
00025 #include <vcl_cassert.h>
00026 #include <mbl/mbl_file_data_collector.h>
00027 #include <mbl/mbl_data_collector_list.h>
00028 #include <mbl/mbl_index_sort.h>
00029
00030
00031
00032 clsfy_direct_boost_builder::clsfy_direct_boost_builder()
00033 : save_data_to_disk_(false), bs_(-1), max_n_clfrs_(-1), weak_builder_(0)
00034 {
00035 }
00036
00037
00038
00039 clsfy_direct_boost_builder::~clsfy_direct_boost_builder()
00040 {
00041 }
00042
00043
00044
00045
00046 bool clsfy_direct_boost_builder::is_class(vcl_string const& s) const
00047 {
00048 return s == clsfy_direct_boost_builder::is_a() || clsfy_builder_base::is_class(s);
00049 }
00050
00051
00052
00053 vcl_string clsfy_direct_boost_builder::is_a() const
00054 {
00055 return vcl_string("clsfy_direct_boost_builder");
00056 }
00057
00058
00059
00060 double clsfy_direct_boost_builder::calc_prop_same(
00061 const vcl_vector<bool>& vec1,
00062 const vcl_vector<bool>& vec2) const
00063 {
00064 unsigned n = vec1.size();
00065 assert( n==vec2.size() );
00066 int sum= 0;
00067 for (unsigned i=0;i<n;++i)
00068 if (vec1[i]==vec2[i])
00069 ++sum;
00070
00071 return sum*1.0/n;
00072 }
00073
00074
00075
00076 double clsfy_direct_boost_builder::calc_threshold(
00077 clsfy_direct_boost& strong_classifier,
00078 mbl_data_wrapper<vnl_vector<double> >& inputs,
00079 const vcl_vector<unsigned>& outputs) const
00080 {
00081
00082 unsigned long n = inputs.size();
00083 vcl_vector<double> scores(n);
00084 inputs.reset();
00085 for (unsigned long i=0;i<n;++i)
00086 {
00087 scores[i]= strong_classifier.log_l( inputs.current() );
00088 inputs.next();
00089 }
00090
00091
00092 unsigned int tot_pos=0;
00093 for (unsigned long i=0;i<n;++i)
00094 if ( outputs[i] == 1 ) ++tot_pos;
00095
00096
00097 vcl_vector<int> index;
00098 mbl_index_sort(scores, index);
00099
00100 unsigned int n_pos=0;
00101 unsigned int n_neg=0;
00102 unsigned long min_error= n+1;
00103 double min_thresh= -1;
00104 for (unsigned long i=0;i<n;++i)
00105 {
00106 #ifdef DEBUG
00107 vcl_cout<<" scores[ index["<<i<<"] ] = "<< scores[ index[i] ]<<" ; "
00108 <<"outputs[ index["<<i<<"] ] = "<<outputs[ index[i] ]<<'\n';
00109 #endif
00110 if ( outputs[ index[i] ] == 0 ) ++n_neg;
00111 else if ( outputs[ index[i] ] == 1 ) ++n_pos;
00112 else
00113 {
00114 vcl_cout<<"ERROR: clsfy_direct_boost_basic_builder::calc_threshold()\n"
00115 <<"Unrecognised output value\n"
00116 <<"outputs[ index["<<i<<"] ] = outputs["<<index[i]<<"] = "
00117 <<outputs[index[i]]<<'\n';
00118 vcl_abort();
00119 }
00120
00121 #ifdef DEBUG
00122 vcl_cout<<"n = "<<n<<", n_pos= "<<n_pos<<", n_neg= "<<n_neg<<'\n';
00123 #endif
00124 unsigned int error= n_neg+(tot_pos-n_pos);
00125
00126 if ( error<= min_error )
00127 {
00128 min_error= error;
00129 min_thresh = scores[ index[i] ] + 0.001 ;
00130 #ifdef DEBUG
00131 vcl_cout<<"error= "<<error<<", min_thresh= "<<min_thresh<<'\n';
00132 #endif
00133 }
00134 }
00135
00136 assert( n_pos + n_neg == n );
00137 #ifdef DEBUG
00138 vcl_cout<<"min_error= "<<min_error<<", min_thresh= "<<min_thresh<<'\n';
00139 #endif
00140
00141 return min_thresh;
00142 }
00143
00144
00145
00146
00147
00148 double clsfy_direct_boost_builder::build(clsfy_classifier_base& model,
00149 mbl_data_wrapper<vnl_vector<double> >& inputs,
00150 unsigned ,
00151 const vcl_vector<unsigned> &outputs) const
00152 {
00153
00154
00155 assert( model.is_class("clsfy_direct_boost") );
00156 clsfy_direct_boost &strong_classifier = (clsfy_direct_boost&) model;
00157
00158
00159
00160 if ( max_n_clfrs_ < 0 )
00161 {
00162 vcl_cout<<"Error: clsfy_direct_boost_builder::build\n"
00163 <<"max_n_clfrs_ = "<<max_n_clfrs_<<" ie < 0\n"
00164 <<"set using set_max_n_clfrs()\n";
00165 vcl_abort();
00166 }
00167 else
00168 {
00169 vcl_cout<<"Maximum number of classifiers to be found by Adaboost ="
00170 <<max_n_clfrs_<<'\n';
00171 }
00172
00173 if ( weak_builder_ == 0 )
00174 {
00175 vcl_cout<<"Error: clsfy_direct_boost_builder::build\n"
00176 <<"weak_builder_ pointer has not been set\n"
00177 <<"need to provide a builder to build each weak classifier\n"
00178 <<"set using set_weak_builder()\n";
00179 vcl_abort();
00180 }
00181 else
00182 {
00183 vcl_cout<<"Weak learner used by AdaBoost ="
00184 <<weak_builder_->is_a()<<'\n';
00185 }
00186
00187 if ( bs_ < 0 )
00188 {
00189 vcl_cout<<"Error: clsfy_direct_boost_builder::build\n"
00190 <<"bs_ = "<<bs_<<" ie < 0\n"
00191 <<"set using set_batch_size()\n";
00192 vcl_abort();
00193 }
00194 else
00195 {
00196 vcl_cout<<"Batch size when sorting data =" <<bs_<<'\n';
00197 }
00198
00199
00200 assert(bs_>0);
00201 assert(bs_!=1);
00202 assert (max_n_clfrs_ >= 0);
00203
00204
00205
00206
00207
00208
00209
00210 unsigned long n = inputs.size();
00211
00212
00213
00214 inputs.reset();
00215 unsigned int d = inputs.current().size();
00216
00217
00218
00219
00220
00221 vcl_string temp_path= "temp.dat";
00222 mbl_file_data_collector< vnl_vector<double> >
00223 file_collector( temp_path );
00224
00225 mbl_data_collector_list< vnl_vector< double > >
00226 ram_collector;
00227
00228 mbl_data_collector<vnl_vector< double> >* collector;
00229
00230 if (save_data_to_disk_)
00231 {
00232 vcl_cout<<"saving data to disk!\n";
00233 collector= &file_collector;
00234 }
00235 else
00236 {
00237
00238 vcl_cout<<"saving data to ram!\n";
00239 collector= &ram_collector;
00240 }
00241
00242
00243
00244
00245
00246
00247 vcl_vector< vnl_vector< double > >vec(bs_);
00248
00249 vcl_cout<<"d= "<<d<<vcl_endl;
00250 unsigned int b=0;
00251 while ( b+1<d )
00252 {
00253 int r= vcl_min ( bs_, int(d-b) );
00254 assert(r>0);
00255
00256 vcl_cout<<"arranging weak classifier data = "<<b<<" to "
00257 <<(b+r)-1<<" of "<<d<<vcl_endl;
00258
00259
00260 for (int i=0; i< bs_; ++i)
00261 vec[i].set_size(n);
00262
00263
00264 inputs.reset();
00265 for (unsigned long j=0;j<n;++j)
00266 {
00267 for (int i=0; i< r; ++i)
00268 vec[i](j)=( inputs.current()[b+i] );
00269 inputs.next();
00270 }
00271
00272
00273 for (int i=0; i< r; ++i)
00274 {
00275
00276 assert (vec[i].size() == n);
00277 assert (n != 0);
00278
00279
00280 collector->record(vec[i]);
00281 }
00282
00283 b+=bs_;
00284 }
00285
00286
00287 mbl_data_wrapper< vnl_vector<double> >&
00288 wrapper=collector->data_wrapper();
00289
00290
00291
00292 wrapper.reset();
00293 assert ( wrapper.current().size() == n );
00294 assert ( d == wrapper.size() );
00295
00296
00297
00298 clsfy_classifier_1d* c1d = weak_builder_->new_classifier();
00299
00300
00301 vnl_vector<double> wts(n,1.0/n);
00302
00303
00304
00305 vcl_vector< double > errors(0);
00306 vcl_vector< vcl_vector<bool> > responses(0);
00307 vcl_vector< clsfy_classifier_1d* > classifiers(0);
00308
00309 wrapper.reset();
00310 for (unsigned int i=0; i<d; ++i )
00311 {
00312 const vnl_vector<double>& vec= wrapper.current();
00313 double error= weak_builder_->build(*c1d, vec, wts, outputs);
00314
00315 vcl_vector<bool> resp_vec(n);
00316
00317 for (unsigned long k=0; k<n;++k)
00318 {
00319 unsigned int r= c1d->classify( vec(k) );
00320 if (r==0)
00321 resp_vec[k]=false;
00322 else
00323 resp_vec[k]=true;
00324 }
00325
00326 responses.push_back( resp_vec );
00327 errors.push_back( error );
00328 classifiers.push_back( c1d->clone() );
00329
00330 wrapper.next();
00331 }
00332
00333 delete c1d;
00334
00335
00336
00337
00338
00339 vcl_vector<int> index;
00340 mbl_index_sort(errors, index);
00341
00342
00343
00344
00345 strong_classifier.clear();
00346 strong_classifier.set_n_dims(d);
00347
00348
00349 for (int k=0; k<max_n_clfrs_; ++k)
00350 {
00351 if (index.size() == 0 ) break;
00352
00353
00354 int ind= index[0];
00355 vcl_cout<<"ind= "<<ind<<", errors["<<ind<<"]= "<<errors[ind]<<'\n';
00356 if (errors[ind]> 0.5 ) break;
00357
00358 if (errors[ind]==0)
00359 strong_classifier.add_one_classifier( classifiers[ind], 1.0, ind);
00360 else
00361 strong_classifier.add_one_classifier( classifiers[ind], 1.0/errors[ind], ind);
00362
00363 if (calc_all_thresholds_)
00364 {
00365
00366
00367 double t=calc_threshold( strong_classifier, inputs, outputs );
00368 strong_classifier.add_one_threshold(t);
00369 }
00370 else
00371 {
00372
00373 strong_classifier.add_one_threshold(0.0);
00374 }
00375
00376 if (errors[ind]==0) break;
00377
00378
00379
00380 vcl_vector<int> new_index(0);
00381 vcl_vector<bool>& i_vec=responses[ind];
00382 unsigned int m=index.size();
00383 unsigned int n_rejects=0;
00384 for (unsigned int j=0; j<m; ++j)
00385 {
00386 vcl_vector<bool>& j_vec=responses[ index[j] ];
00387 double prop_same= calc_prop_same(i_vec,j_vec);
00388
00389 if ( prop_same < prop_ )
00390 new_index.push_back( index[j] );
00391 else
00392 ++n_rejects;
00393 }
00394
00395 vcl_cout<<"number of rejects due to similarity= "<<n_rejects<<vcl_endl;
00396
00397
00398
00399
00400 index= new_index;
00401
00402
00403
00404 }
00405 for (unsigned i =0; i< classifiers.size(); ++i)
00406 delete classifiers[i];
00407
00408
00409
00410 double t=calc_threshold( strong_classifier, inputs, outputs );
00411 strong_classifier.add_final_threshold(t);
00412
00413
00414
00415
00416 vcl_cout<<"calculating training error\n";
00417 return clsfy_test_error(strong_classifier, inputs, outputs);
00418 }
00419
00420
00421
00422
00423 clsfy_classifier_base* clsfy_direct_boost_builder::new_classifier() const
00424 {
00425 return new clsfy_direct_boost();
00426 }
00427
00428
00429
00430
#if 0
// Compiled out: user-defined copy constructor and assignment operator.
// Both refer to data_ptr_ and data_, which are not used anywhere else in this
// file.

// Copy constructor: start with a null data_ptr_, then reuse operator=.
clsfy_direct_boost_builder::clsfy_direct_boost_builder(const clsfy_direct_boost_builder& new_b)
  : data_ptr_(0)
{
  *this = new_b;
}

// Assignment: replace the owned data_ptr_ with a clone of the source object's
// and copy data_. Self-assignment is a no-op.
clsfy_direct_boost_builder& clsfy_direct_boost_builder::operator=(const clsfy_direct_boost_builder& new_b)
{
  if (this != &new_b)
  {
    delete data_ptr_;
    data_ptr_ = 0;

    if (new_b.data_ptr_)
      data_ptr_ = new_b.data_ptr_->clone();

    data_ = new_b.data_;
  }

  return *this;
}

#endif // 0
00460
00461
00462
00463
00464 clsfy_builder_base* clsfy_direct_boost_builder::clone() const
00465 {
00466 return new clsfy_direct_boost_builder(*this);
00467 }
00468
00469
00470
00471
00472 void clsfy_direct_boost_builder::print_summary(vcl_ostream& ) const
00473 {
00474 #if 0
00475 clsfy_builder_base::print_summary(os);
00476 vsl_print_summary(os, data_);
00477 #endif
00478
00479 vcl_cerr << "clsfy_direct_boost_builder::print_summary() NYI\n";
00480 }
00481
00482
00483
00484
00485 void clsfy_direct_boost_builder::b_write(vsl_b_ostream& ) const
00486 {
00487 #if 0
00488 vsl_b_write(bfs, version_no());
00489 clsfy_builder_base::b_write(bfs);
00490 vsl_b_write(bfs, data_);
00491 #endif
00492 vcl_cerr << "clsfy_direct_boost_builder::b_write() NYI\n";
00493 }
00494
00495
00496
00497
00498 void clsfy_direct_boost_builder::b_read(vsl_b_istream& )
00499 {
00500 vcl_cerr << "clsfy_direct_boost_builder::b_read() NYI\n";
00501 #if 0
00502 if (!bfs) return;
00503
00504 short version;
00505 vsl_b_read(bfs,version);
00506 switch (version)
00507 {
00508 case 1:
00509
00510 vsl_b_read(bfs,data_);
00511 break;
00512 default:
00513 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_direct_boost_builder&)\n"
00514 << " Unknown version number "<< version << '\n';
00515 bfs.is().clear(vcl_ios::badbit);
00516 return;
00517 }
00518 #endif // 0
00519 }