contrib/mul/clsfy/clsfy_classifier_1d.cxx
Go to the documentation of this file.
00001 //:
00002 // \file
00003 // \brief Describe an abstract classifier of 1D data
00004 // \author Tim Cootes
00005 
00006 #include "clsfy_classifier_1d.h"
00007 #include <vcl_iostream.h>
00008 #include <vcl_cassert.h>
00009 #include <vcl_vector.h>
00010 #include <vsl/vsl_indent.h>
00011 #include <vsl/vsl_binary_loader.h>
00012 
00013 //=======================================================================
00014 
00015 clsfy_classifier_1d::clsfy_classifier_1d()
00016 {
00017 }
00018 
00019 //=======================================================================
00020 
00021 clsfy_classifier_1d::~clsfy_classifier_1d()
00022 {
00023 }
00024 
00025 //=======================================================================
00026 
00027 unsigned clsfy_classifier_1d::classify(double input) const
00028 {
00029   unsigned N = n_classes();
00030 
00031   vcl_vector<double> probs;
00032   class_probabilities(probs, input);
00033 
00034   if (N == 1) // This is a binary classifier
00035   {
00036     if (probs[0] > 0.5)
00037       return 1u;
00038     else return 0u;
00039   }
00040   else
00041   {
00042     unsigned bestIndex = 0;
00043     unsigned i = 1;
00044     double bestProb = probs[bestIndex];
00045 
00046     while (i < N)
00047     {
00048       if (probs[i] > bestProb)
00049       {
00050         bestIndex = i;
00051         bestProb = probs[i];
00052       }
00053       i++;
00054     }
00055     return bestIndex;
00056   }
00057 }
00058 
00059 //=======================================================================
00060 
00061 void clsfy_classifier_1d::classify_many(vcl_vector<unsigned> &outputs, mbl_data_wrapper<double> &inputs) const
00062 {
00063   outputs.resize(inputs.size());
00064 
00065   inputs.reset();
00066   unsigned i=0;
00067 
00068   do
00069   {
00070     outputs[i++] = classify(inputs.current());
00071   } while (inputs.next());
00072 }
00073 
00074 //=======================================================================
00075 
00076 vcl_string clsfy_classifier_1d::is_a() const
00077 {
00078   return vcl_string("clsfy_classifier_1d");
00079 }
00080 
00081 bool clsfy_classifier_1d::is_class(vcl_string const& s) const
00082 {
00083   return s == clsfy_classifier_1d::is_a();
00084 }
00085 
00086 //=======================================================================
00087 
00088 vcl_ostream& operator<<(vcl_ostream& os, clsfy_classifier_1d const& b)
00089 {
00090   os << b.is_a() << ": ";
00091   vsl_indent_inc(os);
00092   b.print_summary(os);
00093   vsl_indent_dec(os);
00094   return os;
00095 }
00096 
00097 //=======================================================================
00098 
00099 vcl_ostream& operator<<(vcl_ostream& os,const clsfy_classifier_1d* b)
00100 {
00101   if (b)
00102     return os << *b;
00103   else
00104     return os << vsl_indent() << "No clsfy_classifier_1d defined.";
00105 }
00106 
00107 //=======================================================================
00108 
00109 void vsl_add_to_binary_loader(const clsfy_classifier_1d& b)
00110 {
00111   vsl_binary_loader<clsfy_classifier_1d>::instance().add(b);
00112 }
00113 
00114 //=======================================================================
00115 
00116 void vsl_b_write(vsl_b_ostream& os, const clsfy_classifier_1d& b)
00117 {
00118   b.b_write(os);
00119 }
00120 
00121 //=======================================================================
00122 
00123 void vsl_b_read(vsl_b_istream& bfs, clsfy_classifier_1d& b)
00124 {
00125   b.b_read(bfs);
00126 }
00127 
00128 //=======================================================================
00129 //: Calculate the fraction of test samples which are classified incorrectly
00130 double clsfy_test_error(const clsfy_classifier_1d &classifier,
00131                         mbl_data_wrapper<double> & test_inputs,
00132                         const vcl_vector<unsigned> & test_outputs)
00133 {
00134   assert(test_inputs.size() == test_outputs.size());
00135 
00136   vcl_vector<unsigned> results;
00137   classifier.classify_many(results, test_inputs);
00138   unsigned sum_diff = 0;
00139   const unsigned n = results.size();
00140   for (unsigned i=0; i < n; ++i)
00141     if (results[i] != test_outputs[i]) sum_diff++;
00142   return ((double) sum_diff) / ((double) n);
00143 }
00144