Go to the documentation of this file.00001
00002
00003 #include "clsfy_rbf_svm_smo_1_builder.h"
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include <vcl_string.h>
00013 #include <vcl_vector.h>
00014 #include <vcl_sstream.h>
00015 #include <vcl_algorithm.h>
00016 #include <vcl_cassert.h>
00017 #include <vul/vul_string.h>
00018
00019 #include <mbl/mbl_data_wrapper.h>
00020 #include <mbl/mbl_parse_block.h>
00021 #include <mbl/mbl_read_props.h>
00022
00023 #include <clsfy/clsfy_smo_1.h>
00024
00025
00026
//: Map a classifier label to an SVM target value.
//  Label 1 becomes the positive target (+1); any other label becomes -1.
inline int class_to_svm_target(unsigned v)
{
  if (v != 1) return -1;
  return 1;
}
00028
00029
00030
00031
00032 double clsfy_rbf_svm_smo_1_builder::build(clsfy_classifier_base& classifier,
00033 mbl_data_wrapper<vnl_vector<double> >& inputs,
00034 const vcl_vector<unsigned> &outputs) const
00035 {
00036 inputs.reset();
00037
00038 const unsigned int nSamples = inputs.size();
00039 assert(outputs.size() == nSamples);
00040 assert(*vcl_max_element(outputs.begin(), outputs.end()) <= 1);
00041
00042 assert(classifier.is_class("clsfy_rbf_svm"));
00043 clsfy_rbf_svm &svm = static_cast<clsfy_rbf_svm &>(classifier);
00044
00045 clsfy_smo_1_rbf svAPI;
00046 vcl_vector<int> targets(nSamples);
00047 vcl_transform(outputs.begin(), outputs.end(),
00048 targets.begin(), class_to_svm_target);
00049
00050 svAPI.set_data(inputs, targets);
00051
00052
00053
00054 svAPI.set_C(boundC_);
00055 svAPI.set_gamma(1.0/(2.0*rbf_width_*rbf_width_));
00056
00057 svAPI.calc();
00058
00059
00060
00061 {
00062 vcl_vector<vnl_vector<double> > supportVectors;
00063 const vnl_vector<double> &allAlphas = svAPI.lagrange_mults();
00064 vcl_vector<double> alphas;
00065 vcl_vector<unsigned> labels;
00066 for (unsigned i=0; i<nSamples; ++i)
00067 if (allAlphas[i]!=0.0)
00068 {
00069 alphas.push_back(allAlphas[i]);
00070 labels.push_back(outputs[i]);
00071 inputs.set_index(i);
00072 supportVectors.push_back(inputs.current());
00073 }
00074 svm.set(supportVectors, alphas, labels, rbf_width_, svAPI.bias());
00075 }
00076
00077 return svAPI.error_rate();
00078 }
00079
00080
00081
00082
00083
00084 double clsfy_rbf_svm_smo_1_builder::build(clsfy_classifier_base& classifier,
00085 mbl_data_wrapper<vnl_vector<double> >& inputs,
00086 unsigned nClasses,
00087 const vcl_vector<unsigned> &outputs) const
00088 {
00089 assert(nClasses == 1);
00090 return build(classifier, inputs, outputs);
00091 }
00092
00093
00094
00095 double clsfy_rbf_svm_smo_1_builder::rbf_width() const
00096 {
00097 return rbf_width_;
00098 }
00099
00100
00101
00102 void clsfy_rbf_svm_smo_1_builder::set_rbf_width(double rbf_width)
00103 {
00104 rbf_width_ = rbf_width;
00105 }
00106
00107
00108 vcl_string clsfy_rbf_svm_smo_1_builder::is_a() const
00109 {
00110 return vcl_string("clsfy_rbf_svm_smo_1_builder");
00111 }
00112
00113
00114
00115 bool clsfy_rbf_svm_smo_1_builder::is_class(vcl_string const& s) const
00116 {
00117 return s == clsfy_rbf_svm_smo_1_builder::is_a() || clsfy_builder_base::is_class(s);
00118 }
00119
00120
00121
00122 short clsfy_rbf_svm_smo_1_builder::version_no() const
00123 {
00124 return 1;
00125 }
00126
00127
00128
00129 clsfy_builder_base* clsfy_rbf_svm_smo_1_builder::clone() const
00130 {
00131 return new clsfy_rbf_svm_smo_1_builder(*this);
00132 }
00133
00134
00135
00136 void clsfy_rbf_svm_smo_1_builder::print_summary(vcl_ostream& os) const
00137 {
00138
00139 os << "RBF width = " << rbf_width_ << ", bounds = " << boundC_;
00140 }
00141
00142
00143
00144 void clsfy_rbf_svm_smo_1_builder::b_write(vsl_b_ostream& bfs) const
00145 {
00146 vsl_b_write(bfs,version_no());
00147 vsl_b_write(bfs,boundC_);
00148 vsl_b_write(bfs,rbf_width_);
00149 }
00150
00151
00152
00153 void clsfy_rbf_svm_smo_1_builder::b_read(vsl_b_istream& bfs)
00154 {
00155 if (!bfs) return;
00156
00157 short version;
00158 vsl_b_read(bfs,version);
00159 switch (version)
00160 {
00161 case (1):
00162 vsl_b_read(bfs,boundC_);
00163 vsl_b_read(bfs,rbf_width_);
00164 break;
00165 default:
00166 vcl_cerr << "I/O ERROR: clsfy_rbf_svm_smo_1_builder::b_read(vsl_b_istream&)\n"
00167 << " Unknown version number "<< version << '\n';
00168 bfs.is().clear(vcl_ios::badbit);
00169 return;
00170 }
00171 }
00172
00173
00174
00175
00176
00177
00178
00179
00180
00181
00182
00183
00184
00185
00186
00187 void clsfy_rbf_svm_smo_1_builder::config(vcl_istream &as)
00188 {
00189 vcl_string s = mbl_parse_block(as);
00190
00191 vcl_istringstream ss(s);
00192 mbl_read_props_type props = mbl_read_props_ws(ss);
00193
00194 {
00195 boundC_= vul_string_atof(props.get_optional_property("boundC", "0.0"));
00196 rbf_width_= vul_string_atof(props.get_optional_property("rbf_width", "0.0"));
00197 }
00198
00199
00200 mbl_read_props_look_for_unused_props(
00201 "clsfy_rbf_svm_smo_1_builder::config", props, mbl_read_props_type());
00202 }