Go to the documentation of this file.00001 #include "clsfy_direct_boost.h"
00002
00003
00004
00005
00006
00007
00008
00009 #include <vcl_string.h>
00010 #include <vcl_iostream.h>
00011 #include <vcl_vector.h>
00012 #include <vcl_cassert.h>
00013 #include <vcl_cmath.h>
00014 #include <vsl/vsl_binary_io.h>
00015 #include <vsl/vsl_binary_loader.h>
00016 #include <vsl/vsl_vector_io.h>
00017 #include <vnl/io/vnl_io_matrix.h>
00018 #include <vnl/io/vnl_io_vector.h>
00019
00020
00021
00022 clsfy_direct_boost::clsfy_direct_boost()
00023 : n_clfrs_used_(-1) , n_dims_(-1)
00024 {
00025 }
00026
00027 clsfy_direct_boost::clsfy_direct_boost(const clsfy_direct_boost& c)
00028 : clsfy_classifier_base()
00029 {
00030 *this = c;
00031 }
00032
00033
00034 clsfy_direct_boost& clsfy_direct_boost::operator=(const clsfy_direct_boost& c)
00035 {
00036 delete_stuff();
00037
00038 int n = c.classifier_1d_.size();
00039 classifier_1d_.resize(n);
00040 for (int i=0;i<n;++i)
00041 classifier_1d_[i] = c.classifier_1d_[i]->clone();
00042
00043 threshes_ = c.threshes_;
00044 wts_ = c.wts_;
00045 index_ = c.index_;
00046 return *this;
00047 }
00048
00049
00050 void clsfy_direct_boost::delete_stuff()
00051 {
00052 for (unsigned int i=0;i<classifier_1d_.size();++i)
00053 delete classifier_1d_[i];
00054
00055 classifier_1d_.resize(0);
00056
00057 threshes_.resize(0);
00058 wts_.resize(0);
00059 index_.resize(0);
00060 n_clfrs_used_= -1;
00061 }
00062
00063
00064 clsfy_direct_boost::~clsfy_direct_boost()
00065 {
00066 delete_stuff();
00067 }
00068
00069
00070
00071 bool clsfy_direct_boost::operator==(const clsfy_direct_boost& x) const
00072 {
00073 if (x.classifier_1d_.size() != classifier_1d_.size() ) return false;
00074 int n= x.classifier_1d_.size();
00075 for (int i=0; i<n; ++i)
00076 if (!(*(x.classifier_1d_[i]) == *(classifier_1d_[i]) )) return false;
00077
00078 return x.threshes_ == threshes_ &&
00079 x.wts_ == wts_ &&
00080 x.index_ == index_;
00081 }
00082
00083
00084
00085 void clsfy_direct_boost::set_parameters(
00086 const vcl_vector<clsfy_classifier_1d*>& classifier,
00087 const vcl_vector<double>& threshes,
00088 const vcl_vector<double>& wts,
00089 const vcl_vector<int>& index)
00090 {
00091 delete_stuff();
00092
00093 int n = classifier.size();
00094 classifier_1d_.resize(n);
00095 for (int i=0;i<n;++i)
00096 classifier_1d_[i] = classifier[i]->clone();
00097
00098 threshes_ = threshes;
00099 wts_ = wts;
00100 index_= index;
00101 }
00102
00103
00104
00105 void clsfy_direct_boost::clear()
00106 {
00107 delete_stuff();
00108 }
00109
00110
00111
00112
00113
00114 void clsfy_direct_boost::add_one_classifier(clsfy_classifier_1d* c1d,
00115 double wt,
00116 int index)
00117 {
00118 classifier_1d_.push_back(c1d->clone());
00119 wts_.push_back(wt);
00120 index_.push_back(index);
00121 n_clfrs_used_=wts_.size();
00122 }
00123
00124
00125
00126 void clsfy_direct_boost::add_one_threshold(double thresh)
00127 {
00128 threshes_.push_back(thresh);
00129 }
00130
00131
00132
00133 void clsfy_direct_boost::add_final_threshold(double thresh)
00134 {
00135 int n= threshes_.size();
00136 threshes_[n-1]= thresh;
00137 }
00138
00139
00140
00141
00142 unsigned clsfy_direct_boost::classify(const vnl_vector<double> &v) const
00143 {
00144
00145
00146 assert ( n_clfrs_used_ >= 0);
00147 assert ( (unsigned)n_clfrs_used_ <= wts_.size() );
00148 assert ( n_dims_ >= 0);
00149 assert ( v.size() == (unsigned)n_dims_ );
00150
00151
00152 double sum = 0.0;
00153 for (int i=0;i<n_clfrs_used_;++i)
00154 sum+= wts_[i]* classifier_1d_[i]->log_l( v[ index_[i] ] );
00155
00156
00157 if (sum < threshes_[n_clfrs_used_-1] ) return 1;
00158 return 0;
00159 }
00160
00161
00162
00163
00164
00165 void clsfy_direct_boost::class_probabilities(vcl_vector<double> &outputs,
00166 const vnl_vector<double> &input) const
00167 {
00168 outputs.resize(1);
00169 outputs[0] = 1.0 / (1.0 + vcl_exp(-log_l(input)));
00170 }
00171
00172
00173
00174
00175
00176 double clsfy_direct_boost::log_l(const vnl_vector<double> &v) const
00177 {
00178 assert ( n_clfrs_used_ >= 0);
00179 assert ( (unsigned)n_clfrs_used_ <= wts_.size() );
00180
00181
00182 double sum = 0.0;
00183 for (int i=0;i<n_clfrs_used_;++i)
00184 sum+= wts_[i]* classifier_1d_[i]->log_l( v[ index_[i] ] );
00185
00186
00187 return sum;
00188
00189 }
00190
00191
00192
00193 vcl_string clsfy_direct_boost::is_a() const
00194 {
00195 return vcl_string("clsfy_direct_boost");
00196 }
00197
00198
00199
00200 bool clsfy_direct_boost::is_class(vcl_string const& s) const
00201 {
00202 return s == clsfy_direct_boost::is_a() || clsfy_classifier_base::is_class(s);
00203 }
00204
00205
00206
00207
00208 void clsfy_direct_boost::print_summary(vcl_ostream& os) const
00209 {
00210 int n = wts_.size();
00211 assert( wts_.size() == index_.size() );
00212 os<<'\n';
00213 for (int i=0;i<n;++i)
00214 {
00215 os<<" Weights: "<<wts_[i]
00216 <<" Index: "<<index_[i]
00217 <<" Total Threshold: "<<threshes_[i]
00218 <<" Classifier: "<<classifier_1d_[i]<<'\n';
00219 }
00220 }
00221
00222
00223
00224 short clsfy_direct_boost::version_no() const
00225 {
00226 return 1;
00227 }
00228
00229
00230
00231 void clsfy_direct_boost::b_write(vsl_b_ostream& bfs) const
00232 {
00233 vsl_b_write(bfs,version_no());
00234 vsl_b_write(bfs,classifier_1d_);
00235 vsl_b_write(bfs,threshes_);
00236 vsl_b_write(bfs,wts_);
00237 vsl_b_write(bfs,index_);
00238 }
00239
00240
00241
00242 void clsfy_direct_boost::b_read(vsl_b_istream& bfs)
00243 {
00244 if (!bfs) return;
00245
00246 delete_stuff();
00247
00248 short version;
00249 vsl_b_read(bfs,version);
00250 switch (version)
00251 {
00252 case (1):
00253 vsl_b_read(bfs,classifier_1d_);
00254 vsl_b_read(bfs,threshes_);
00255 vsl_b_read(bfs,wts_);
00256 vsl_b_read(bfs,index_);
00257
00258
00259 n_clfrs_used_= index_.size();
00260
00261 break;
00262 default:
00263 vcl_cerr << "I/O ERROR: clsfy_direct_boost::b_read(vsl_b_istream&)\n"
00264 << " Unknown version number "<< version << '\n';
00265 bfs.is().clear(vcl_ios::badbit);
00266 }
00267 }