Go to the documentation of this file.00001
00002 #include "clsfy_mean_square_1d_builder.h"
00003
00004
00005
00006
00007
00008 #include <vcl_cmath.h>
00009 #include <vcl_iostream.h>
00010 #include <vcl_string.h>
00011 #include <vcl_cassert.h>
00012 #include <vcl_cstdlib.h>
00013 #include <vsl/vsl_binary_loader.h>
00014 #include <vnl/vnl_double_2.h>
00015 #include <clsfy/clsfy_builder_1d.h>
00016 #include <clsfy/clsfy_mean_square_1d.h>
00017 #include <vcl_algorithm.h>
00018
00019
00020
00021 clsfy_mean_square_1d_builder::clsfy_mean_square_1d_builder()
00022 {
00023 }
00024
00025
00026
00027 clsfy_mean_square_1d_builder::~clsfy_mean_square_1d_builder()
00028 {
00029 }
00030
00031
00032
00033 short clsfy_mean_square_1d_builder::version_no() const
00034 {
00035 return 1;
00036 }
00037
00038
00039
00040
00041 clsfy_classifier_1d* clsfy_mean_square_1d_builder::new_classifier() const
00042 {
00043 return new clsfy_mean_square_1d();
00044 }
00045
00046
00047
00048
00049
00050
00051
00052
00053 double clsfy_mean_square_1d_builder::build(clsfy_classifier_1d& classifier,
00054 const vnl_vector<double>& egs,
00055 const vnl_vector<double>& wts,
00056 const vcl_vector<unsigned> &outputs) const
00057 {
00058
00059 assert(classifier.is_class("clsfy_mean_square_1d"));
00060
00061 unsigned int n = egs.size();
00062 assert ( wts.size() == n );
00063 assert ( outputs.size() == n );
00064
00065
00066 double wm_pos= 0.0;
00067 double tot_pos_wts=0.0, tot_neg_wts=0.0;
00068 unsigned int n_pos=0, n_neg=0;
00069 for (unsigned int i=0; i<n; ++i)
00070 {
00071 #ifdef DEBUG
00072 vcl_cout<<"egs["<<i<<"]= "<<egs[i]<<vcl_endl
00073 <<"wts["<<i<<"]= "<<wts[i]<<vcl_endl
00074 <<"outputs["<<i<<"]= "<<outputs[i]<<vcl_endl;
00075 #endif
00076 if ( outputs[i] == 1 )
00077 {
00078
00079 wm_pos+= wts(i)*egs(i);
00080 tot_pos_wts+= wts(i);
00081 ++n_pos;
00082 }
00083 else
00084 {
00085 tot_neg_wts+= wts(i);
00086 ++n_neg;
00087 }
00088 }
00089
00090 assert( n_pos+n_neg== n );
00091 wm_pos/=tot_pos_wts;
00092 #ifdef DEBUG
00093 vcl_cout<<"wm_pos= "<<wm_pos<<vcl_endl;
00094 #endif
00095
00096 vcl_vector<vbl_triple<double,int,int> > data;
00097
00098 vbl_triple<double,int,int> t;
00099
00100 for (unsigned int i=0;i<n;++i)
00101 {
00102 double k= wm_pos-egs[i];
00103 t.first=k*k;
00104 t.second= outputs[i];
00105 t.third = i;
00106 data.push_back(t);
00107 }
00108
00109 vbl_triple<double,int,int> *data_ptr=&data[0];
00110 vcl_sort(data_ptr,data_ptr+n);
00111
00112 double wt_pos=0;
00113 double wt_neg=0;
00114 double min_error= 1000000;
00115 double min_thresh= -1;
00116 for (unsigned int i=0;i<n;++i)
00117 {
00118 if ( data[i].second == 0 ) wt_neg+= wts[ data[i].third] ;
00119 else if ( data[i].second == 1 ) wt_pos+= wts[ data[i].third];
00120 else
00121 {
00122 vcl_cout<<"ERROR: clsfy_mean_square_1d_builder::build()\n"
00123 <<"Unrecognised output value in triple (ie must be 0 or 1)\n"
00124 <<"data.second="<<data[i].second<<vcl_endl;
00125 vcl_abort();
00126 }
00127 double error= tot_pos_wts-wt_pos+wt_neg;
00128 #ifdef DEBUG
00129 vcl_cout<<"data[i].first= "<<data[i].first<<vcl_endl
00130 <<"data[i].second= "<<data[i].second<<vcl_endl
00131 <<"data[i].third= "<<data[i].third<<vcl_endl
00132
00133 <<"wt_pos= "<<wt_pos<<vcl_endl
00134 <<"tot_wts1= "<<tot_wts1<<vcl_endl
00135 <<"wt_neg= "<<wt_neg<<vcl_endl
00136
00137 <<"error= "<<error<<vcl_endl;
00138 #endif
00139 if ( error< min_error )
00140 {
00141 min_error= error;
00142 min_thresh = data[i].first + 0.001 ;
00143 }
00144 }
00145
00146 assert( vcl_fabs (wt_pos - tot_pos_wts) < 1e-9 );
00147 assert( vcl_fabs (wt_neg - tot_neg_wts) < 1e-9 );
00148 #ifdef DEBUG
00149 vcl_cout<<"min_error= "<<min_error<<vcl_endl
00150 <<"min_thresh= "<<min_thresh<<vcl_endl;
00151 #endif
00152
00153 classifier.set_params(vnl_double_2(wm_pos,min_thresh).as_vector());
00154 return min_error;
00155 }
00156
00157
00158
00159
00160
00161 double clsfy_mean_square_1d_builder::build(clsfy_classifier_1d& classifier,
00162 vnl_vector<double>& egs0,
00163 vnl_vector<double>& wts0,
00164 vnl_vector<double>& egs1,
00165 vnl_vector<double>& wts1) const
00166 {
00167
00168 assert(classifier.is_class("clsfy_mean_square_1d"));
00169
00170
00171
00172 unsigned int n0 = egs0.size();
00173 unsigned int n1 = egs1.size();
00174 assert (wts0.size() == n0 );
00175 assert (wts1.size() == n1 );
00176
00177
00178 double tot_wts1= wts1.mean()*n1;
00179 double wm_pos=0.0;
00180 for (unsigned int i=0; i< n1; ++i)
00181 {
00182 wm_pos+= wts1(i)*egs1(i);
00183 #ifdef DEBUG
00184 vcl_cout<<"egs1("<<i<<")= "<<egs1(i)<<vcl_endl
00185 <<"wts1("<<i<<")= "<<wts1(i)<<vcl_endl;
00186 #endif
00187 }
00188 wm_pos/=tot_wts1;
00189
00190 vcl_cout<<"wm_pos= "<<wm_pos<<vcl_endl;
00191
00192 vcl_vector<vbl_triple<double,int,int> > data;
00193
00194 vnl_vector<double> wts(n0+n1);
00195 vbl_triple<double,int,int> t;
00196
00197 for (unsigned int i=0;i<n0;++i)
00198 {
00199 double k= wm_pos-egs0[i];
00200 t.first=k*k;
00201 t.second=0;
00202 t.third = i;
00203 wts(i)= wts0[i];
00204 data.push_back(t);
00205 }
00206
00207
00208 for (unsigned int i=0;i<n1;++i)
00209 {
00210 double k= wm_pos-egs1[i];
00211 t.first=k*k;
00212 t.second=1;
00213 t.third = i+n0;
00214 wts(i+n0)= wts1[i];
00215 data.push_back(t);
00216 }
00217
00218 unsigned int n=n0+n1;
00219
00220 vbl_triple<double,int,int> *data_ptr=&data[0];
00221 vcl_sort(data_ptr,data_ptr+n);
00222
00223 double wt_pos=0;
00224 double wt_neg=0;
00225 double min_error= 1000000;
00226 double min_thresh= -1;
00227 for (unsigned int i=0;i<n;++i)
00228 {
00229 if ( data[i].second == 0 ) wt_neg+= wts[ data[i].third] ;
00230 else if ( data[i].second == 1 ) wt_pos+= wts[ data[i].third];
00231 else
00232 {
00233 vcl_cout<<"ERROR: clsfy_mean_square_1d_builder::build()\n"
00234 <<"Unrecognised output value in triple (ie must be 0 or 1)\n"
00235 <<"data.second="<<data[i].second<<vcl_endl;
00236 vcl_abort();
00237 }
00238 double error= tot_wts1-wt_pos+wt_neg;
00239 #ifdef DEBUG
00240 vcl_cout<<"data[i].first= "<<data[i].first<<vcl_endl
00241 <<"data[i].second= "<<data[i].second<<vcl_endl
00242 <<"data[i].third= "<<data[i].third<<vcl_endl
00243
00244 <<"wt_pos= "<<wt_pos<<vcl_endl
00245 <<"tot_wts1= "<<tot_wts1<<vcl_endl
00246 <<"wt_neg= "<<wt_neg<<vcl_endl
00247
00248 <<"error= "<<error<<vcl_endl;
00249 #endif
00250 if ( error< min_error )
00251 {
00252 min_error= error;
00253 min_thresh = data[i].first + 0.001 ;
00254 }
00255 }
00256
00257 assert( vcl_fabs (wt_pos - tot_wts1) < 1e-9 );
00258 assert( vcl_fabs (wt_neg - wts0.mean()*n0) < 1e-9 );
00259 vcl_cout<<"min_error= "<<min_error<<vcl_endl
00260 <<"min_thresh= "<<min_thresh<<vcl_endl;
00261
00262
00263 classifier.set_params(vnl_double_2(wm_pos,min_thresh).as_vector());
00264 return min_error;
00265 }
00266
00267
00268
00269
00270 double clsfy_mean_square_1d_builder::build_from_sorted_data(
00271 clsfy_classifier_1d& ,
00272 const vbl_triple<double,int,int>* ,
00273 const vnl_vector<double>&
00274 ) const
00275 {
00276 vcl_cout<<"ERROR: clsfy_mean_square_1d_builder::build_from_sorted_data()\n"
00277 <<"Function not implemented because can't use pre-sorted data\n"
00278 <<"the weighted mean of the data is needed to calc the ordering!\n";
00279 vcl_abort();
00280
00281 return 0.0;
00282 }
00283
00284
00285
00286 vcl_string clsfy_mean_square_1d_builder::is_a() const
00287 {
00288 return vcl_string("clsfy_mean_square_1d_builder");
00289 }
00290
00291 bool clsfy_mean_square_1d_builder::is_class(vcl_string const& s) const
00292 {
00293 return s == clsfy_mean_square_1d_builder::is_a() || clsfy_builder_1d::is_class(s);
00294 }
00295
00296
00297
#if 0 // two functions commented out

// Copy constructor (disabled): starts with a null data_ptr_ and then
// defers to the assignment operator.
clsfy_mean_square_1d_builder::clsfy_mean_square_1d_builder(
    const clsfy_mean_square_1d_builder& new_b)
  : data_ptr_(0)
{
  *this = new_b;
}

// Assignment operator (disabled): replaces data_ptr_ with a clone of the
// source's object (if it has one) and copies data_.
clsfy_mean_square_1d_builder&
clsfy_mean_square_1d_builder::operator=(const clsfy_mean_square_1d_builder& new_b)
{
  if (this != &new_b)
  {
    // Drop whatever is currently held before taking a copy
    delete data_ptr_;
    data_ptr_ = 0;

    if (new_b.data_ptr_)
      data_ptr_ = new_b.data_ptr_->clone();

    data_ = new_b.data_;
  }

  return *this;
}

#endif // 0
00329
00330
00331
00332 clsfy_builder_1d* clsfy_mean_square_1d_builder::clone() const
00333 {
00334 return new clsfy_mean_square_1d_builder(*this);
00335 }
00336
00337
00338
00339
00340 void clsfy_mean_square_1d_builder::print_summary(vcl_ostream& ) const
00341 {
00342
00343
00344
00345 vcl_cerr << "clsfy_mean_square_1d_builder::print_summary() NYI\n";
00346 }
00347
00348
00349
00350
00351 void clsfy_mean_square_1d_builder::b_write(vsl_b_ostream& ) const
00352 {
00353
00354
00355
00356 vcl_cerr << "clsfy_mean_square_1d_builder::b_write() NYI\n";
00357 }
00358
00359
00360
00361
00362 void clsfy_mean_square_1d_builder::b_read(vsl_b_istream& )
00363 {
00364 vcl_cerr << "clsfy_mean_square_1d_builder::b_read() NYI\n";
00365 #if 0
00366 if (!bfs) return;
00367
00368 short version;
00369 vsl_b_read(bfs,version);
00370 switch (version)
00371 {
00372 case (1):
00373
00374 vsl_b_read(bfs,data_);
00375 break;
00376 default:
00377 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_mean_square_1d_builder&)\n"
00378 << " Unknown version number "<< version << '\n';
00379 bfs.is().clear(vcl_ios::badbit);
00380 return;
00381 }
00382 #endif
00383 }