Go to the documentation of this file.00001
00002 #include "clsfy_binary_threshold_1d_builder.h"
00003
00004
00005
00006
00007
00008 #include <vcl_iostream.h>
00009 #include <vcl_string.h>
00010 #include <vcl_cassert.h>
00011 #include <vsl/vsl_binary_loader.h>
00012 #include <vnl/vnl_double_2.h>
00013 #include <clsfy/clsfy_builder_1d.h>
00014 #include <clsfy/clsfy_binary_threshold_1d.h>
00015 #include <vcl_algorithm.h>
00016
00017
00018
00019 clsfy_binary_threshold_1d_builder::clsfy_binary_threshold_1d_builder()
00020 {
00021 }
00022
00023
00024
00025 clsfy_binary_threshold_1d_builder::~clsfy_binary_threshold_1d_builder()
00026 {
00027 }
00028
00029
00030
00031 short clsfy_binary_threshold_1d_builder::version_no() const
00032 {
00033 return 1;
00034 }
00035
00036
00037
00038
00039 clsfy_classifier_1d* clsfy_binary_threshold_1d_builder::new_classifier() const
00040 {
00041 return new clsfy_binary_threshold_1d();
00042 }
00043
00044
00045
00046
00047
00048
00049
00050
00051 double clsfy_binary_threshold_1d_builder::build(clsfy_classifier_1d& classifier,
00052 const vnl_vector<double>& egs,
00053 const vnl_vector<double>& wts,
00054 const vcl_vector<unsigned> &outputs) const
00055 {
00056
00057 assert(classifier.is_class("clsfy_binary_threshold_1d"));
00058
00059 unsigned int n= egs.size();
00060 assert ( wts.size() == n );
00061 assert ( outputs.size() == n );
00062
00063
00064 vcl_vector<vbl_triple<double,int,int> > data;
00065
00066 vbl_triple<double,int,int> t;
00067
00068 for (unsigned int i=0;i<n;++i)
00069 {
00070 t.first=egs(i);
00071 t.second= outputs[i];
00072 t.third = i;
00073 data.push_back(t);
00074 }
00075
00076 vbl_triple<double,int,int> *data_ptr=&data[0];
00077 vcl_sort(data_ptr,data_ptr+n);
00078 return build_from_sorted_data(classifier, &data[0], wts);
00079 }
00080
00081
00082
00083
00084
00085 double clsfy_binary_threshold_1d_builder::build(clsfy_classifier_1d& classifier,
00086 vnl_vector<double>& egs0,
00087 vnl_vector<double>& wts0,
00088 vnl_vector<double>& egs1,
00089 vnl_vector<double>& wts1) const
00090 {
00091
00092 assert(classifier.is_class("clsfy_binary_threshold_1d"));
00093
00094 vcl_vector<vbl_triple<double,int,int> > data;
00095 unsigned int n0 = egs0.size();
00096 unsigned int n1 = egs1.size();
00097 vnl_vector<double> wts(n0+n1);
00098 vbl_triple<double,int,int> t;
00099
00100 for (unsigned int i=0;i<n0;++i)
00101 {
00102 t.first=egs0[i];
00103 t.second=0;
00104 t.third = i;
00105 wts(i)= wts0[i];
00106 data.push_back(t);
00107 }
00108
00109
00110 for (unsigned int i=0;i<n1;++i)
00111 {
00112 t.first=egs1[i];
00113 t.second=1;
00114 t.third = i+n0;
00115 wts(i+n0)= wts1[i];
00116 data.push_back(t);
00117 }
00118
00119 unsigned int n=n0+n1;
00120
00121 vbl_triple<double,int,int> *data_ptr=&data[0];
00122 vcl_sort(data_ptr,data_ptr+n);
00123
00124 return build_from_sorted_data(classifier,&data[0], wts);
00125 }
00126
00127
00128
00129
00130 double clsfy_binary_threshold_1d_builder::build_from_sorted_data(
00131 clsfy_classifier_1d& classifier,
00132 const vbl_triple<double,int,int> *data,
00133 const vnl_vector<double>& wts
00134 ) const
00135 {
00136
00137
00138
00139
00140
00141
00142
00143 unsigned int n=wts.size();
00144 double tot_wts0=0.0, tot_wts1=0.0;
00145 for (unsigned int i=0;i<n;++i)
00146 if (data[i].second==0)
00147 tot_wts0+=wts(data[i].third);
00148 else
00149 tot_wts1+=wts(data[i].third);
00150
00151 double e0=0.0, e1=0.0, min_err=2.0;
00152 double etot0,etot1;
00153 unsigned int index=n; int polarity=0;
00154 for (unsigned int i=0;i<n;++i)
00155 {
00156 if (data[i].second==0)
00157 e0+=wts(data[i].third);
00158 else
00159 e1+=wts(data[i].third);
00160
00161 etot0=(tot_wts0-e0) +e1;
00162 etot1=(tot_wts1-e1) +e0;
00163
00164 if ( etot0< min_err)
00165 {
00166
00167
00168 polarity=+1;
00169 index=i;
00170
00171 min_err= etot0;
00172 }
00173
00174 if ( etot1< min_err)
00175 {
00176
00177
00178 polarity=-1;
00179 index=i;
00180
00181 min_err= etot1;
00182 }
00183 }
00184
00185 assert ( index!=n );
00186
00187
00188 double threshold;
00189 if ( index+1==n )
00190 threshold=data[index].first+0.01;
00191 else
00192 threshold=(data[index].first+data[index+1].first)/2;
00193
00194
00195 vnl_double_2 params(polarity, threshold*polarity);
00196 classifier.set_params(params.as_vector());
00197 return min_err;
00198 }
00199
00200
00201
00202 vcl_string clsfy_binary_threshold_1d_builder::is_a() const
00203 {
00204 return vcl_string("clsfy_binary_threshold_1d_builder");
00205 }
00206
00207 bool clsfy_binary_threshold_1d_builder::is_class(vcl_string const& s) const
00208 {
00209 return s == clsfy_binary_threshold_1d_builder::is_a() || clsfy_builder_1d::is_class(s);
00210 }
00211
00212
00213
00214 clsfy_builder_1d* clsfy_binary_threshold_1d_builder::clone() const
00215 {
00216 return new clsfy_binary_threshold_1d_builder(*this);
00217 }
00218
00219
00220
00221
00222 void clsfy_binary_threshold_1d_builder::print_summary(vcl_ostream& ) const
00223 {
00224 }
00225
00226
00227
00228
00229 void clsfy_binary_threshold_1d_builder::b_write(vsl_b_ostream& bfs) const
00230 {
00231 short version_no=1;
00232 vsl_b_write(bfs, version_no);
00233 }
00234
00235
00236
00237
00238 void clsfy_binary_threshold_1d_builder::b_read(vsl_b_istream& bfs)
00239 {
00240 if (!bfs) return;
00241
00242 short version;
00243 vsl_b_read(bfs,version);
00244 switch (version)
00245 {
00246 case 1:
00247 break;
00248 default:
00249 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_binary_threshold_1d_builder&)\n"
00250 << " Unknown version number "<< version << '\n';
00251 bfs.is().clear(vcl_ios::badbit);
00252 return;
00253 }
00254 }