contrib/mul/clsfy/clsfy_random_classifier.cxx
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_random_classifier.cxx
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005 // Copyright (c) 2001: British Telecommunications plc
00006 
00007 //:
00008 // \file
00009 // \brief  Implement a random classifier
00010 // \author iscott
00011 // \date   Tue Oct  9 10:21:59 2001
00012 
00013 #include "clsfy_random_classifier.h"
00014 
00015 #include <vcl_iostream.h>
00016 #include <vcl_string.h>
00017 #include <vcl_cassert.h>
00018 #include <vsl/vsl_binary_loader.h>
00019 #include <clsfy/clsfy_classifier_base.h>
00020 #include <vsl/vsl_vector_io.h>
00021 #include <vnl/vnl_math.h>
00022 
00023 //=======================================================================
00024 
00025 clsfy_random_classifier::clsfy_random_classifier():
00026 confidence_(0.0), n_dims_(0u)
00027 {
00028 }
00029 
00030 //=======================================================================
00031 
00032 vcl_string clsfy_random_classifier::is_a() const
00033 {
00034   return vcl_string("clsfy_random_classifier");
00035 }
00036 
00037 //=======================================================================
00038 
00039 bool clsfy_random_classifier::is_class(vcl_string const& s) const
00040 {
00041   return s == clsfy_random_classifier::is_a() || clsfy_classifier_base::is_class(s);
00042 }
00043 
00044 //=======================================================================
00045 
00046 clsfy_classifier_base* clsfy_random_classifier::clone() const
00047 {
00048   return new clsfy_random_classifier(*this);
00049 }
00050 
00051 //=======================================================================
00052 
00053     // required if data is present in this base class
00054 void clsfy_random_classifier::print_summary(vcl_ostream& os) const
00055 {
00056   os << "Prior probs = "; vsl_print_summary(os, probs_);
00057   os << ", confidence = " << confidence_<<'\n';
00058 }
00059 
00060 //=======================================================================
00061 
00062 static short version_no = 1;
00063 
00064   // required if data is present in this base class
00065 void clsfy_random_classifier::b_write(vsl_b_ostream& bfs) const
00066 {
00067   vsl_b_write(bfs, version_no);
00068   vsl_b_write(bfs, probs_);
00069   vsl_b_write(bfs, confidence_);
00070   vsl_b_write(bfs, n_dims_);
00071 }
00072 
00073 //=======================================================================
00074 
00075   // required if data is present in this base class
00076 void clsfy_random_classifier::b_read(vsl_b_istream& bfs)
00077 {
00078   if (!bfs) return;
00079 
00080   short version;
00081   vsl_b_read(bfs, version);
00082   switch (version)
00083   {
00084   case (1):
00085     vsl_b_read(bfs, probs_);
00086     calc_min_to_win();
00087     vsl_b_read(bfs, confidence_);
00088     vsl_b_read(bfs, n_dims_);
00089     break;
00090   default:
00091     vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_random_classifier&)\n"
00092              << "           Unknown version number "<< version << "\n";
00093     bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00094   }
00095 }
00096 
00097 //=======================================================================
00098 
00099 double clsfy_random_classifier::confidence() const
00100 {
00101   return confidence_;
00102 }
00103 
00104 //=======================================================================
00105 
00106 void clsfy_random_classifier::set_confidence(double confidence)
00107 {
00108   assert(confidence >= 0.0);
00109   confidence_ = confidence;
00110 }
00111 
00112 //=======================================================================
00113 
00114 const vcl_vector<double> & clsfy_random_classifier::probs() const
00115 {
00116   return probs_;
00117 }
00118 
00119 //=======================================================================
00120 
00121 void clsfy_random_classifier::calc_min_to_win()
00122 {
00123   const unsigned n = probs_.size();
00124   min_to_win_.resize(n);
00125   for (unsigned i=0; i<n; ++i)
00126   {
00127     double maxval = -1;
00128     for (unsigned j=0; j<n; ++j)
00129     {
00130       if (j==i) continue;
00131       if (probs_[j] > maxval)
00132         maxval = probs_[j];
00133     }
00134     min_to_win_[i] = maxval - probs_[i] + vnl_math::sqrteps;
00135   }
00136 }
00137 
00138 //=======================================================================
00139 typedef vnl_c_vector<double> cvd;
00140 
00141 void clsfy_random_classifier::set_probs(const vcl_vector<double> &probs)
00142 {
00143   probs_ = probs;
00144   const unsigned n = probs_.size();
00145   assert(n > 1);
00146 
00147   double * const p = &probs_.front();
00148 
00149   cvd::scale(p, p, n, 1.0/cvd::sum(p, n));
00150 
00151   calc_min_to_win();
00152 }
00153 
00154 //=======================================================================
00155 
00156 void clsfy_random_classifier::set_n_dims(unsigned n_dims)
00157 {
00158   n_dims_ = n_dims;
00159 }
00160 
00161 //=======================================================================
00162 
00163 //: The dimensionality of input vectors.
00164 unsigned clsfy_random_classifier::n_dims() const
00165 {
00166   return n_dims_;
00167 }
00168 
00169 //=======================================================================
00170 
00171 //: The number of possible output classes.
00172 unsigned clsfy_random_classifier::n_classes() const
00173 {
00174   return probs_.size()==2?1:probs_.size();
00175 }
00176 
00177 //=======================================================================
00178 
00179 //: Return the probability the input being in each class.
00180 // output(i) i<nClasses, contains the probability that the input is in class i
00181 void clsfy_random_classifier::class_probabilities(vcl_vector<double> &outputs, const vnl_vector<double> &input) const
00182 {
00183   if (last_inputs_ != input)
00184   {
00185     last_inputs_ = input;
00186     unsigned i=0;
00187     double x=rng_.drand64() -probs_[0];
00188     while (x>=0)
00189       x-= probs_[++i];
00190 
00191     const unsigned n = probs_.size();
00192     assert(i<n);
00193 
00194     last_outputs_ = probs_;
00195     last_outputs_[i] += min_to_win_[i] + vnl_math_abs(rng_.normal()) * confidence_;
00196 
00197     double * const p = &last_outputs_[0];
00198     cvd::scale(p, p, n, 1.0/cvd::sum(p, n));
00199   }
00200 
00201 // Convert a two-class output into a binary output
00202   if (last_outputs_.size() == 2)
00203   {
00204     outputs.resize(1);
00205     outputs[0] = last_outputs_[1];
00206   }
00207   else
00208     outputs = last_outputs_;
00209 }
00210 
00211 //=======================================================================
00212 
00213 //: Log likelihood of being in class (binary classifiers only)
00214 // class probability = 1 / (1+exp(-log_l))
00215 // Operation of this method is undefined for multiclass classifiers.
00216 double clsfy_random_classifier::log_l(const vnl_vector<double> &input) const
00217 {
00218   assert (n_classes() == 1);
00219   vcl_vector<double> prob(1);
00220   class_probabilities(prob, input);
00221   return vcl_log(prob[0]/(1-prob[0]));
00222 }
00223 
00224 //=======================================================================
00225 
00226 void clsfy_random_classifier::reseed(unsigned long seed)
00227 {
00228   rng_.reseed(seed);
00229 }