contrib/mul/clsfy/clsfy_k_nearest_neighbour.cxx
Go to the documentation of this file.
00001 //  Copyright: (C) 2000 British Telecommunications plc
00002 
00003 //:
00004 // \file
00005 
00006 #include "clsfy_k_nearest_neighbour.h"
00007 
00008 #include <vcl_string.h>
00009 #include <vcl_utility.h>
00010 #include <vcl_cassert.h>
00011 
00012 #include <vnl/vnl_math.h>
00013 #include <vsl/vsl_binary_io.h>
00014 #include <vsl/vsl_vector_io.h>
00015 
00016 #include <mbl/mbl_priority_bounded_queue.h>
00017 
00018 //: Set the training data.
00019 void clsfy_k_nearest_neighbour::set(const vcl_vector<vnl_vector<double> > &inputs,
00020                                     const vcl_vector<unsigned> &outputs)
00021 {
00022   assert(inputs.size() == outputs.size());
00023   trainInputs_ = inputs;
00024   trainOutputs_ = outputs;
00025 }
00026 
00027 // stuff to get the priority queue to work happily
00028 typedef vcl_pair<double, unsigned> pairDV;
00029 struct first_lt { bool operator()(const pairDV &x, const pairDV &y)
00030 { return x.first < y.first;} };
00031 struct second_eq_one { bool operator()(const pairDV &x) {return x.second == 1;}};
00032 
00033 //: Return the classification of the given probe vector.
00034 unsigned clsfy_k_nearest_neighbour::classify(const vnl_vector<double> &input) const
00035 {
00036   const unsigned nTrainingVecs = trainInputs_.size();
00037   const unsigned k = vnl_math_min(k_, nTrainingVecs-1 + (nTrainingVecs%2));
00038   mbl_priority_bounded_queue<pairDV, vcl_vector<pairDV>, first_lt >  pq(k);
00039   unsigned i;
00040 
00041   for (i = 0; i < nTrainingVecs; i++)
00042     pq.push(vcl_make_pair(vnl_vector_ssd(input, trainInputs_[i]), trainOutputs_[i]));
00043 
00044   unsigned count = 0;
00045   for (i = 0; i < k; i++)
00046   {
00047     count += pq.top().second;
00048     pq.pop();
00049   }
00050   return count *2 > k;
00051 }
00052 
00053 
00054 //: Return a probability like value that the input being in each class.
00055 // output(i) i<<nClasses, contains the probability that the input
00056 // is in class i;
00057 void clsfy_k_nearest_neighbour::class_probabilities(vcl_vector<double> &outputs,
00058                                                     const vnl_vector<double> &input) const
00059 {
00060   const unsigned nTrainingVecs = trainInputs_.size();
00061   const unsigned k = vnl_math_min(k_, nTrainingVecs-1 + (nTrainingVecs%2));
00062   mbl_priority_bounded_queue<pairDV, vcl_vector<pairDV>, first_lt >  pq(k);
00063   unsigned i;
00064 
00065   for (i = 0; i < nTrainingVecs; i++)
00066     pq.push(vcl_make_pair(vnl_vector_ssd(input, trainInputs_[i]), trainOutputs_[i]));
00067 
00068   unsigned count = 0;
00069   for (i = 0; i < k; i++)
00070   {
00071     count += pq.top().second;
00072     pq.pop();
00073   }
00074   outputs.resize(1);
00075   outputs[0] = ((double)count)/ ((double) k);
00076 }
00077 
00078 //: The dimensionality of input vectors.
00079 unsigned clsfy_k_nearest_neighbour::n_dims() const
00080 {
00081   if (trainInputs_.size() == 0)
00082     return 0;
00083   else
00084     return trainInputs_[0].size();
00085 }
00086 
00087 
00088 //=======================================================================
00089 
00090 //: This value has properties of a Log likelihood of being in class (binary classifiers only)
00091 // class probability = exp(logL) / (1+exp(logL))
00092 double clsfy_k_nearest_neighbour::log_l(const vnl_vector<double> &input) const
00093 {
00094   vcl_vector<double> outputs(1);
00095   class_probabilities(outputs, input);
00096   double prob = outputs[0];
00097   return vcl_log(prob/(1-prob));
00098 }
00099 
00100 //=======================================================================
00101 
00102 vcl_string clsfy_k_nearest_neighbour::is_a() const
00103 {
00104   return vcl_string("clsfy_k_nearest_neighbour");
00105 }
00106 
00107 //=======================================================================
00108 
00109 bool clsfy_k_nearest_neighbour::is_class(vcl_string const& s) const
00110 {
00111   return s == clsfy_k_nearest_neighbour::is_a() || clsfy_classifier_base::is_class(s);
00112 }
00113 
00114 //=======================================================================
00115 
00116 short clsfy_k_nearest_neighbour::version_no() const
00117 {
00118   return 1;
00119 }
00120 
00121 //=======================================================================
00122 
00123 clsfy_classifier_base* clsfy_k_nearest_neighbour::clone() const
00124 {
00125   return new clsfy_k_nearest_neighbour(*this);
00126 }
00127 
00128 //=======================================================================
00129 
00130 void clsfy_k_nearest_neighbour::print_summary(vcl_ostream& os) const
00131 {
00132   os << trainInputs_.size() << " training samples, k=" << k_;
00133 }
00134 
00135 //=======================================================================
00136 
00137 void clsfy_k_nearest_neighbour::b_write(vsl_b_ostream& bfs) const
00138 {
00139   vsl_b_write(bfs,version_no());
00140   vsl_b_write(bfs,k_);
00141   vsl_b_write(bfs,trainOutputs_);
00142   vsl_b_write(bfs,trainInputs_);
00143 }
00144 
00145 //=======================================================================
00146 
00147 void clsfy_k_nearest_neighbour::b_read(vsl_b_istream& bfs)
00148 {
00149   if (!bfs) return;
00150 
00151   short version;
00152   vsl_b_read(bfs,version);
00153   switch (version)
00154   {
00155   case (1):
00156     vsl_b_read(bfs,k_);
00157     vsl_b_read(bfs,trainOutputs_);
00158     vsl_b_read(bfs,trainInputs_);
00159     break;
00160   default:
00161     vcl_cerr << "I/O ERROR: clsfy_k_nearest_neighbour::b_read(vsl_b_istream&)\n";
00162     vcl_cerr << "           Unknown version number "<< version << "\n";
00163     bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00164     return;
00165   }
00166 }