contrib/mul/clsfy/clsfy_classifier_base.cxx
Go to the documentation of this file.
00001 // Copyright: (C) 2000 British Telecommunications plc
00002 #include "clsfy_classifier_base.h"
00003 //:
00004 // \file
00005 // \brief Implement bits of an abstract classifier
00006 // \author Ian Scott
00007 // \date 2000-05-10
00008 // \verbatim
00009 //  Modifications
00010 //   2 May 2001 IMS Converted to VXL
00011 // \endverbatim
00012 
00013 #include <vcl_iostream.h>
00014 #include <vcl_cassert.h>
00015 #include <vcl_vector.h>
00016 #include <vsl/vsl_indent.h>
00017 #include <vsl/vsl_binary_loader.h>
00018 
00019 //=======================================================================
00020 
00021 unsigned clsfy_classifier_base::classify(const vnl_vector<double> &input) const
00022 {
00023   unsigned N = n_classes();
00024 
00025   vcl_vector<double> probs;
00026   class_probabilities(probs, input);
00027 
00028   if (N == 1) // This is a binary classifier
00029   {
00030     if (probs[0] > 0.5)
00031       return 1u;
00032     else return 0u;
00033   }
00034   else
00035   {
00036     unsigned bestIndex = 0;
00037     unsigned i = 1;
00038     double bestProb = probs[bestIndex];
00039 
00040     while (i < N)
00041     {
00042       if (probs[i] > bestProb)
00043       {
00044         bestIndex = i;
00045         bestProb = probs[i];
00046       }
00047       i++;
00048     }
00049     return bestIndex;
00050   }
00051 }
00052 
00053 //=======================================================================
00054 
00055 void clsfy_classifier_base::classify_many(vcl_vector<unsigned> &outputs, mbl_data_wrapper<vnl_vector<double> > &inputs) const
00056 {
00057   outputs.resize(inputs.size());
00058 
00059   inputs.reset();
00060   unsigned i=0;
00061 
00062   do
00063   {
00064     outputs[i++] = classify(inputs.current());
00065   } while (inputs.next());
00066 }
00067 
00068 //=======================================================================
00069 
00070 vcl_string clsfy_classifier_base::is_a() const
00071 {
00072   return vcl_string("clsfy_classifier_base");
00073 }
00074 
00075 //=======================================================================
00076 
00077 bool clsfy_classifier_base::is_class(vcl_string const& s) const
00078 {
00079   return s == clsfy_classifier_base::is_a();
00080 }
00081 
00082 //=======================================================================
00083 
00084 vcl_ostream& operator<<(vcl_ostream& os, clsfy_classifier_base const& b)
00085 {
00086   os << b.is_a() << ": ";
00087   vsl_indent_inc(os);
00088   b.print_summary(os);
00089   vsl_indent_dec(os);
00090   return os;
00091 }
00092 
00093 //=======================================================================
00094 
00095 vcl_ostream& operator<<(vcl_ostream& os,const clsfy_classifier_base* b)
00096 {
00097   if (b)
00098     return os << *b;
00099   else
00100     return os << vsl_indent() << "No clsfy_classifier_base defined.";
00101 }
00102 
00103 //=======================================================================
00104 
00105 void vsl_add_to_binary_loader(const clsfy_classifier_base& b)
00106 {
00107   vsl_binary_loader<clsfy_classifier_base>::instance().add(b);
00108 }
00109 
00110 //=======================================================================
00111 
00112 void vsl_b_write(vsl_b_ostream& os, const clsfy_classifier_base& b)
00113 {
00114   b.b_write(os);
00115 }
00116 
00117 //=======================================================================
00118 
00119 void vsl_b_read(vsl_b_istream& bfs, clsfy_classifier_base& b)
00120 {
00121   b.b_read(bfs);
00122 }
00123 
00124 //=======================================================================
00125 //: Calculate the fraction of test samples which are classified incorrectly
00126 double clsfy_test_error(const clsfy_classifier_base &classifier,
00127                         mbl_data_wrapper<vnl_vector<double> > & test_inputs,
00128                         const vcl_vector<unsigned> & test_outputs)
00129 {
00130   assert(test_inputs.size() == test_outputs.size());
00131   if (test_inputs.size()==0) return -1;
00132 
00133   vcl_vector<unsigned> results;
00134   classifier.classify_many(results, test_inputs);
00135   unsigned sum_diff = 0;
00136   const unsigned n = results.size();
00137   for (unsigned i=0; i < n; ++i)
00138     if (results[i] != test_outputs[i]) sum_diff++;
00139   return ((double) sum_diff) / ((double) n);
00140 }
00141 
00142 //=======================================================================
00143 //: Calculate the fraction of test samples of a particular class which are classified incorrectly
00144 // \return -1 if there are no samples of test_class.
00145 double clsfy_test_error(const clsfy_classifier_base &classifier,
00146                         mbl_data_wrapper<vnl_vector<double> > & test_inputs,
00147                         const vcl_vector<unsigned> & test_outputs,
00148                         unsigned test_class)
00149 {
00150   assert(test_inputs.size() == test_outputs.size());
00151   if (test_inputs.size()==0) return -1;
00152   test_inputs.reset();
00153   unsigned n_class=0, n_bad=0, i=0;
00154   do
00155   {
00156     if (test_outputs[i] == test_class)
00157     {
00158       if (test_outputs[i] != classifier.classify(test_inputs.current()))
00159         n_bad ++;
00160       n_class ++;
00161     }
00162     i++;
00163   } while (test_inputs.next());
00164 
00165   if (n_class==0) return -1.0;
00166   return ((double) n_bad) / ((double) n_class);
00167 }
00168