Go to the documentation of this file.00001
00002 #include "clsfy_binary_threshold_1d_gini_builder.h"
00003
00004
00005
00006
00007 #include <vcl_iostream.h>
00008 #include <vcl_string.h>
00009 #include <vcl_cassert.h>
00010 #include <vsl/vsl_binary_loader.h>
00011 #include <vnl/vnl_double_2.h>
00012 #include <clsfy/clsfy_builder_1d.h>
00013 #include <clsfy/clsfy_binary_threshold_1d.h>
00014 #include <vcl_algorithm.h>
00015
00016
00017
00018
00019
00020
00021
00022
00023 clsfy_binary_threshold_1d_gini_builder::clsfy_binary_threshold_1d_gini_builder()
00024 {
00025 }
00026
00027
00028
00029 clsfy_binary_threshold_1d_gini_builder::~clsfy_binary_threshold_1d_gini_builder()
00030 {
00031 }
00032
00033
00034
00035 short clsfy_binary_threshold_1d_gini_builder::version_no() const
00036 {
00037 return 1;
00038 }
00039
00040
00041
00042
00043 clsfy_classifier_1d* clsfy_binary_threshold_1d_gini_builder::new_classifier() const
00044 {
00045 return new clsfy_binary_threshold_1d();
00046 }
00047
00048
00049
00050
00051
00052
00053
00054
00055 double clsfy_binary_threshold_1d_gini_builder::build_gini(clsfy_classifier_1d& classifier,
00056 const vnl_vector<double>& inputs,
00057 const vcl_vector<unsigned> &outputs) const
00058 {
00059 assert(classifier.is_class("clsfy_binary_threshold_1d"));
00060
00061 unsigned n = inputs.size();
00062 assert ( outputs.size() == n );
00063
00064
00065 vcl_vector<vbl_triple<double,int,int> > data;
00066 data.reserve(n);
00067
00068
00069 vcl_vector<unsigned >::const_iterator classIter=outputs.begin();
00070 vnl_vector<double >::const_iterator inputIter=inputs.begin();
00071 vnl_vector<double >::const_iterator inputIterEnd=inputs.end();
00072 vbl_triple<double,int,int> t;
00073 unsigned i=0;
00074 while (inputIter != inputIterEnd)
00075 {
00076 t.first = *inputIter++;
00077 t.second=*classIter++;
00078 t.third = i++;
00079 data.push_back(t);
00080 }
00081
00082 assert(i==inputs.size());
00083
00084 vcl_sort(data.begin(),data.end());
00085 return build_gini_from_sorted_data(static_cast<clsfy_classifier_1d&>(classifier), data);
00086 }
00087
00088
00089
00090
00091
00092
00093 double clsfy_binary_threshold_1d_gini_builder::build_gini_from_sorted_data(
00094 clsfy_classifier_1d& classifier,
00095 const vcl_vector<vbl_triple<double,int,int> >& data) const
00096 {
00097
00098
00099
00100
00101
00102
00103 const double epsilon=1.0E-20;
00104 if (vcl_fabs(data.front().first-data.back().first)<epsilon)
00105 {
00106 vcl_cerr<<"WARNING - clsfy_binary_threshold_1d_gini_builder::build_from_sorted_data - homogeneous data - cannot split\n";
00107 int polarity=1;
00108 double threshold=data[0].first;
00109 vnl_double_2 params(polarity, threshold*polarity);
00110 classifier.set_params(params.as_vector());
00111 return 1.0;
00112 }
00113
00114 unsigned int ntot=data.size();
00115 double dntot=double (ntot);
00116 vcl_vector<vbl_triple<double,int,int> >::const_iterator dataIter=data.begin();
00117 vcl_vector<vbl_triple<double,int,int> >::const_iterator dataIterEnd=data.end();
00118 unsigned n0Tot=0;
00119 unsigned n1Tot=0;
00120 while (dataIter != dataIterEnd)
00121 {
00122 if (dataIter->second==0)
00123 ++n0Tot;
00124 else
00125 ++n1Tot;
00126 ++dataIter;
00127 }
00128
00129 double parentImp=0.0;
00130
00131
00132 double p=double (n0Tot)/dntot;
00133 parentImp=2.0*p*(1-p);
00134
00135 dataIter=data.begin();
00136 double s=dataIter->first-epsilon;
00137 double deltaImpBest= -1.0;
00138 double sbest=s;
00139
00140 unsigned nL0=0;
00141 unsigned nL1=0;
00142 unsigned nR0=n0Tot;
00143 unsigned nR1=n1Tot;
00144 double parity=1.0;
00145 while (dataIter != dataIterEnd)
00146 {
00147 s=dataIter->first;
00148 vcl_vector<vbl_triple<double,int,int> >::const_iterator dataIterNext=dataIter;
00149
00150
00151 while (dataIterNext != dataIterEnd && (dataIterNext->first-s)<epsilon)
00152 {
00153 if (dataIterNext->second==0)
00154 {
00155 ++nL0;
00156 --nR0;
00157 }
00158 else
00159 {
00160 ++nL1;
00161 --nR1;
00162 }
00163 ++dataIterNext;
00164 }
00165
00166 unsigned nLTot=nL0+nL1;
00167 unsigned nRTot=nR0+nR1;
00168 double probL=double(nL0)/double(nLTot);
00169 double probR=double(nR1)/double(nRTot);
00170
00171 double impL=2.0*probL*(1-probL);
00172 double impR=2.0*probR*(1-probR);
00173
00174
00175 double pL=double (nLTot)/dntot;
00176 double pR=1.0-pL;
00177
00178 double deltaImp=parentImp-(pL*impL + pR*impR);
00179 if (deltaImp>deltaImpBest)
00180 {
00181 deltaImpBest=deltaImp;
00182 sbest=s;
00183 if (nR1>=nL1)
00184 parity=1;
00185 else
00186 parity=-1;
00187 }
00188
00189 dataIter=dataIterNext;
00190 }
00191
00192 double threshold=sbest;
00193
00194
00195 vnl_double_2 params(parity, threshold*parity);
00196 classifier.set_params(params.as_vector());
00197 return -deltaImpBest;
00198 }
00199
00200
00201
00202 vcl_string clsfy_binary_threshold_1d_gini_builder::is_a() const
00203 {
00204 return vcl_string("clsfy_binary_threshold_1d_gini_builder");
00205 }
00206
00207 bool clsfy_binary_threshold_1d_gini_builder::is_class(vcl_string const& s) const
00208 {
00209 return s == clsfy_binary_threshold_1d_gini_builder::is_a() || clsfy_builder_1d::is_class(s);
00210 }
00211
#if 0
// Compiled out: explicit copy constructor and assignment operator.
// NOTE(review): data_ptr_ is not referenced anywhere else in this file;
// confirm whether the member still exists before re-enabling this code.

// Copy constructor, implemented in terms of assignment.
clsfy_binary_threshold_1d_gini_builder::clsfy_binary_threshold_1d_gini_builder(
  const clsfy_binary_threshold_1d_gini_builder& new_b)
  : data_ptr_(0)
{
  *this = new_b;
}

// Assignment: copies the clsfy_binary_threshold_1d_builder part only.
clsfy_binary_threshold_1d_gini_builder&
clsfy_binary_threshold_1d_gini_builder::operator=(const clsfy_binary_threshold_1d_gini_builder& new_b)
{
  if (this != &new_b)
  {
    clsfy_binary_threshold_1d_builder& self_base = *this;
    const clsfy_binary_threshold_1d_builder& other_base = new_b;
    self_base = other_base;
  }
  return *this;
}
#endif
00238
00239
00240
00241 void clsfy_binary_threshold_1d_gini_builder::print_summary(vcl_ostream& os) const
00242 {
00243 os<<"clsfy_binary_threshold_1d_gini_builder"<<vcl_endl;
00244 }
00245
00246
00247
00248
00249 clsfy_builder_1d* clsfy_binary_threshold_1d_gini_builder::clone() const
00250 {
00251 return new clsfy_binary_threshold_1d_gini_builder(*this);
00252 }
00253
00254
00255
00256 void clsfy_binary_threshold_1d_gini_builder::b_write(vsl_b_ostream& bfs) const
00257 {
00258 vsl_b_write(bfs, version_no());
00259 clsfy_binary_threshold_1d_builder::b_write(bfs);
00260 }
00261
00262
00263
00264
00265 void clsfy_binary_threshold_1d_gini_builder::b_read(vsl_b_istream& bfs)
00266 {
00267 if (!bfs) return;
00268
00269 short version;
00270 vsl_b_read(bfs,version);
00271 switch (version)
00272 {
00273 case (1):
00274 clsfy_binary_threshold_1d_builder::b_read(bfs);
00275 break;
00276 default:
00277 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_binary_threshold_1d_gini_builder&)\n"
00278 << " Unknown version number "<< version << '\n';
00279 bfs.is().clear(vcl_ios::badbit);
00280 return;
00281 }
00282 }