Go to the documentation of this file.00001
00002
00003
00004
00005
00006 #include "clsfy_k_nearest_neighbour.h"
00007
00008 #include <vcl_string.h>
00009 #include <vcl_utility.h>
00010 #include <vcl_cassert.h>
00011
00012 #include <vnl/vnl_math.h>
00013 #include <vsl/vsl_binary_io.h>
00014 #include <vsl/vsl_vector_io.h>
00015
00016 #include <mbl/mbl_priority_bounded_queue.h>
00017
00018
00019 void clsfy_k_nearest_neighbour::set(const vcl_vector<vnl_vector<double> > &inputs,
00020 const vcl_vector<unsigned> &outputs)
00021 {
00022 assert(inputs.size() == outputs.size());
00023 trainInputs_ = inputs;
00024 trainOutputs_ = outputs;
00025 }
00026
00027
00028 typedef vcl_pair<double, unsigned> pairDV;
00029 struct first_lt { bool operator()(const pairDV &x, const pairDV &y)
00030 { return x.first < y.first;} };
00031 struct second_eq_one { bool operator()(const pairDV &x) {return x.second == 1;}};
00032
00033
00034 unsigned clsfy_k_nearest_neighbour::classify(const vnl_vector<double> &input) const
00035 {
00036 const unsigned nTrainingVecs = trainInputs_.size();
00037 const unsigned k = vnl_math_min(k_, nTrainingVecs-1 + (nTrainingVecs%2));
00038 mbl_priority_bounded_queue<pairDV, vcl_vector<pairDV>, first_lt > pq(k);
00039 unsigned i;
00040
00041 for (i = 0; i < nTrainingVecs; i++)
00042 pq.push(vcl_make_pair(vnl_vector_ssd(input, trainInputs_[i]), trainOutputs_[i]));
00043
00044 unsigned count = 0;
00045 for (i = 0; i < k; i++)
00046 {
00047 count += pq.top().second;
00048 pq.pop();
00049 }
00050 return count *2 > k;
00051 }
00052
00053
00054
00055
00056
00057 void clsfy_k_nearest_neighbour::class_probabilities(vcl_vector<double> &outputs,
00058 const vnl_vector<double> &input) const
00059 {
00060 const unsigned nTrainingVecs = trainInputs_.size();
00061 const unsigned k = vnl_math_min(k_, nTrainingVecs-1 + (nTrainingVecs%2));
00062 mbl_priority_bounded_queue<pairDV, vcl_vector<pairDV>, first_lt > pq(k);
00063 unsigned i;
00064
00065 for (i = 0; i < nTrainingVecs; i++)
00066 pq.push(vcl_make_pair(vnl_vector_ssd(input, trainInputs_[i]), trainOutputs_[i]));
00067
00068 unsigned count = 0;
00069 for (i = 0; i < k; i++)
00070 {
00071 count += pq.top().second;
00072 pq.pop();
00073 }
00074 outputs.resize(1);
00075 outputs[0] = ((double)count)/ ((double) k);
00076 }
00077
00078
00079 unsigned clsfy_k_nearest_neighbour::n_dims() const
00080 {
00081 if (trainInputs_.size() == 0)
00082 return 0;
00083 else
00084 return trainInputs_[0].size();
00085 }
00086
00087
00088
00089
00090
00091
00092 double clsfy_k_nearest_neighbour::log_l(const vnl_vector<double> &input) const
00093 {
00094 vcl_vector<double> outputs(1);
00095 class_probabilities(outputs, input);
00096 double prob = outputs[0];
00097 return vcl_log(prob/(1-prob));
00098 }
00099
00100
00101
00102 vcl_string clsfy_k_nearest_neighbour::is_a() const
00103 {
00104 return vcl_string("clsfy_k_nearest_neighbour");
00105 }
00106
00107
00108
00109 bool clsfy_k_nearest_neighbour::is_class(vcl_string const& s) const
00110 {
00111 return s == clsfy_k_nearest_neighbour::is_a() || clsfy_classifier_base::is_class(s);
00112 }
00113
00114
00115
00116 short clsfy_k_nearest_neighbour::version_no() const
00117 {
00118 return 1;
00119 }
00120
00121
00122
00123 clsfy_classifier_base* clsfy_k_nearest_neighbour::clone() const
00124 {
00125 return new clsfy_k_nearest_neighbour(*this);
00126 }
00127
00128
00129
00130 void clsfy_k_nearest_neighbour::print_summary(vcl_ostream& os) const
00131 {
00132 os << trainInputs_.size() << " training samples, k=" << k_;
00133 }
00134
00135
00136
//: Write this classifier to a binary stream.
// Layout, in order: version_no(), k_, trainOutputs_, trainInputs_.
// b_read() reads the fields back in exactly this order, so the order of
// these calls is part of the file format and must not be changed without
// bumping version_no().
void clsfy_k_nearest_neighbour::b_write(vsl_b_ostream& bfs) const
{
  vsl_b_write(bfs,version_no());
  vsl_b_write(bfs,k_);
  vsl_b_write(bfs,trainOutputs_);
  vsl_b_write(bfs,trainInputs_);
}
00144
00145
00146
00147 void clsfy_k_nearest_neighbour::b_read(vsl_b_istream& bfs)
00148 {
00149 if (!bfs) return;
00150
00151 short version;
00152 vsl_b_read(bfs,version);
00153 switch (version)
00154 {
00155 case (1):
00156 vsl_b_read(bfs,k_);
00157 vsl_b_read(bfs,trainOutputs_);
00158 vsl_b_read(bfs,trainInputs_);
00159 break;
00160 default:
00161 vcl_cerr << "I/O ERROR: clsfy_k_nearest_neighbour::b_read(vsl_b_istream&)\n";
00162 vcl_cerr << " Unknown version number "<< version << "\n";
00163 bfs.is().clear(vcl_ios::badbit);
00164 return;
00165 }
00166 }