Go to the documentation of this file.00001
00002 #include "clsfy_classifier_base.h"
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #include <vcl_iostream.h>
00014 #include <vcl_cassert.h>
00015 #include <vcl_vector.h>
00016 #include <vsl/vsl_indent.h>
00017 #include <vsl/vsl_binary_loader.h>
00018
00019
00020
00021 unsigned clsfy_classifier_base::classify(const vnl_vector<double> &input) const
00022 {
00023 unsigned N = n_classes();
00024
00025 vcl_vector<double> probs;
00026 class_probabilities(probs, input);
00027
00028 if (N == 1)
00029 {
00030 if (probs[0] > 0.5)
00031 return 1u;
00032 else return 0u;
00033 }
00034 else
00035 {
00036 unsigned bestIndex = 0;
00037 unsigned i = 1;
00038 double bestProb = probs[bestIndex];
00039
00040 while (i < N)
00041 {
00042 if (probs[i] > bestProb)
00043 {
00044 bestIndex = i;
00045 bestProb = probs[i];
00046 }
00047 i++;
00048 }
00049 return bestIndex;
00050 }
00051 }
00052
00053
00054
00055 void clsfy_classifier_base::classify_many(vcl_vector<unsigned> &outputs, mbl_data_wrapper<vnl_vector<double> > &inputs) const
00056 {
00057 outputs.resize(inputs.size());
00058
00059 inputs.reset();
00060 unsigned i=0;
00061
00062 do
00063 {
00064 outputs[i++] = classify(inputs.current());
00065 } while (inputs.next());
00066 }
00067
00068
00069
00070 vcl_string clsfy_classifier_base::is_a() const
00071 {
00072 return vcl_string("clsfy_classifier_base");
00073 }
00074
00075
00076
00077 bool clsfy_classifier_base::is_class(vcl_string const& s) const
00078 {
00079 return s == clsfy_classifier_base::is_a();
00080 }
00081
00082
00083
00084 vcl_ostream& operator<<(vcl_ostream& os, clsfy_classifier_base const& b)
00085 {
00086 os << b.is_a() << ": ";
00087 vsl_indent_inc(os);
00088 b.print_summary(os);
00089 vsl_indent_dec(os);
00090 return os;
00091 }
00092
00093
00094
00095 vcl_ostream& operator<<(vcl_ostream& os,const clsfy_classifier_base* b)
00096 {
00097 if (b)
00098 return os << *b;
00099 else
00100 return os << vsl_indent() << "No clsfy_classifier_base defined.";
00101 }
00102
00103
00104
00105 void vsl_add_to_binary_loader(const clsfy_classifier_base& b)
00106 {
00107 vsl_binary_loader<clsfy_classifier_base>::instance().add(b);
00108 }
00109
00110
00111
00112 void vsl_b_write(vsl_b_ostream& os, const clsfy_classifier_base& b)
00113 {
00114 b.b_write(os);
00115 }
00116
00117
00118
00119 void vsl_b_read(vsl_b_istream& bfs, clsfy_classifier_base& b)
00120 {
00121 b.b_read(bfs);
00122 }
00123
00124
00125
00126 double clsfy_test_error(const clsfy_classifier_base &classifier,
00127 mbl_data_wrapper<vnl_vector<double> > & test_inputs,
00128 const vcl_vector<unsigned> & test_outputs)
00129 {
00130 assert(test_inputs.size() == test_outputs.size());
00131 if (test_inputs.size()==0) return -1;
00132
00133 vcl_vector<unsigned> results;
00134 classifier.classify_many(results, test_inputs);
00135 unsigned sum_diff = 0;
00136 const unsigned n = results.size();
00137 for (unsigned i=0; i < n; ++i)
00138 if (results[i] != test_outputs[i]) sum_diff++;
00139 return ((double) sum_diff) / ((double) n);
00140 }
00141
00142
00143
00144
00145 double clsfy_test_error(const clsfy_classifier_base &classifier,
00146 mbl_data_wrapper<vnl_vector<double> > & test_inputs,
00147 const vcl_vector<unsigned> & test_outputs,
00148 unsigned test_class)
00149 {
00150 assert(test_inputs.size() == test_outputs.size());
00151 if (test_inputs.size()==0) return -1;
00152 test_inputs.reset();
00153 unsigned n_class=0, n_bad=0, i=0;
00154 do
00155 {
00156 if (test_outputs[i] == test_class)
00157 {
00158 if (test_outputs[i] != classifier.classify(test_inputs.current()))
00159 n_bad ++;
00160 n_class ++;
00161 }
00162 i++;
00163 } while (test_inputs.next());
00164
00165 if (n_class==0) return -1.0;
00166 return ((double) n_bad) / ((double) n_class);
00167 }
00168