contrib/mul/clsfy/clsfy_classifier_base.h
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_classifier_base.h
00002 // Copyright: (C) 2000 British Telecommunications plc
00003 #ifndef clsfy_classifier_base_h_
00004 #define clsfy_classifier_base_h_
00005 //:
00006 // \file
00007 // \brief Describe an abstract classifier
00008 // \author Ian Scott
00009 // \date 2000-05-10
00010 // \verbatim
00011 //  Modifications
00012 //   2 May 2001 IMS Converted to VXL
00013 // \endverbatim
00014 
00015 #include <vcl_string.h>
00016 #include <vcl_vector.h>
00017 #include <vnl/vnl_vector.h>
00018 #include <mbl/mbl_data_wrapper.h>
00019 #include <vsl/vsl_binary_io.h>
00020 #include <vcl_iostream.h>
00021 
00022 //:  A common interface for 1-out-of-N classifiers
00023 // This class takes a vector and classifies into one of
00024 // N classes.
00025 //
00026 // Derived classes with binary in the name indicates that
00027 // the classifier works with only two classes, 0 and 1.
00028 
00029 class clsfy_classifier_base
00030 {
00031  public:
00032 
00033   // Dflt constructor
00034    clsfy_classifier_base() {}
00035 
00036   // Destructor
00037    virtual ~clsfy_classifier_base() {}
00038 
00039   //: Classify the input vector
00040   // returns a number between 0 and nClasses-1 inclusive to represent the most likely class
00041   virtual unsigned classify(const vnl_vector<double> &input) const;
00042 
00043   //: Return the probability the input being in each class.
00044   // output(i) 0<=i<nClasses, contains the probability that the input is in class i
00045   virtual void class_probabilities(vcl_vector<double> &outputs, const vnl_vector<double> &input) const = 0;
00046 
00047   //: Classify many input vectors
00048   virtual void classify_many(vcl_vector<unsigned> &outputs, mbl_data_wrapper<vnl_vector<double> > &inputs) const;
00049 
00050   //: Log likelihood of being in class (binary classifiers only)
00051   // class probability = 1 / (1+exp(-log_l))
00052   // Operation of this method is undefined for multiclass classifiers
00053   virtual double log_l(const vnl_vector<double> &input) const = 0;
00054 
00055   //: The number of possible output classes.
00056   virtual unsigned n_classes() const = 0;
00057 
00058   //: The dimensionality of input vectors.
00059   virtual unsigned n_dims() const = 0;
00060 
00061   //: Name of the class
00062   virtual vcl_string is_a() const;
00063 
00064   //: Name of the class
00065   virtual bool is_class(vcl_string const& s) const;
00066 
00067   //: Create a copy on the heap and return base class pointer
00068   virtual clsfy_classifier_base* clone() const = 0;
00069 
00070   //: Print class to os
00071   virtual void print_summary(vcl_ostream& os) const = 0;
00072 
00073   //: Save class to binary file stream
00074   virtual void b_write(vsl_b_ostream& bfs) const = 0;
00075 
00076   //: Load class from binary file stream
00077   virtual void b_read(vsl_b_istream& bfs) = 0;
00078 };
00079 
00080 //: Allows derived class to be loaded by base-class pointer
00081 void vsl_add_to_binary_loader(const clsfy_classifier_base& b);
00082 
00083 //: Binary file stream output operator for class reference
00084 void vsl_b_write(vsl_b_ostream& bfs, const clsfy_classifier_base& b);
00085 
00086 //: Binary file stream input operator for class reference
00087 void vsl_b_read(vsl_b_istream& bfs, clsfy_classifier_base& b);
00088 
00089 //: Stream output operator for class reference
00090 vcl_ostream& operator<<(vcl_ostream& os, const clsfy_classifier_base& b);
00091 
00092 //: Stream output operator for class pointer
00093 vcl_ostream& operator<<(vcl_ostream& os, const clsfy_classifier_base* b);
00094 
00095 //: Stream output operator for class reference
00096 inline void vsl_print_summary(vcl_ostream& os, const clsfy_classifier_base& b)
00097 { os << b;}
00098 
00099 //: Stream output operator for class pointer
00100 inline void vsl_print_summary(vcl_ostream& os, const clsfy_classifier_base* b)
00101 { os << b;}
00102 
00103 
00104 //----------------------------------------------------------
00105 
00106 //: Calculate the fraction of test samples which are classified incorrectly
00107 double clsfy_test_error(const clsfy_classifier_base &classifier,
00108                         mbl_data_wrapper<vnl_vector<double> > & test_inputs,
00109                         const vcl_vector<unsigned> & test_outputs);
00110 
00111 //: Calculate the fraction of test samples of a particular class which are classified incorrectly
00112 // \return -1 if there are no samples of test_class.
00113 double clsfy_test_error(const clsfy_classifier_base &classifier,
00114                         mbl_data_wrapper<vnl_vector<double> > & test_inputs,
00115                         const vcl_vector<unsigned> & test_outputs,
00116                         unsigned test_class);
00117 
00118 
00119 #endif // clsfy_classifier_base_h_