contrib/mul/clsfy/clsfy_logit_loss_function.h
Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_logit_loss_function.h
00002 #ifndef clsfy_logit_loss_function_h_
00003 #define clsfy_logit_loss_function_h_
00004 //:
00005 // \file
00006 // \brief Loss function for logit of linear classifier
00007 // \author TFC
00008 
00009 #include <vnl/vnl_cost_function.h>
00010 #include <mbl/mbl_data_wrapper.h>
00011 
00012 //: Loss function for logit of linear classifier.
00013 //  For vector v' = (b w') (ie b=v[0], w=(v[1]...v[n])), computes
00014 //  f(v) = r(v) - (1/n_eg)sum log[(1-minp)logit(c_i * (b+w.x_i)) + minp]
00015 //  where r() is the optional regulariser, c_i = -1 or +1 is the class of example x_i and minp is the min probability.
00016 // This is minus the mean log prob of correct classification (+regulariser)
00017 // which should be minimised to train the classifier.
00018 // NOTE(review): logit() in the formula presumably denotes the logistic function 1/(1+exp(-z)) - confirm against the implementation.
00019 // Note: Regulariser only important to deal with case where perfect
00020 // classification possible, where (as the formula is written) scaling v up would always decrease the unregularised loss.
00021 // Plausible choice of regulariser is clsfy_quad_regulariser (below)
00022 class clsfy_logit_loss_function : public vnl_cost_function
00023 {
00024 private:
00025   mbl_data_wrapper<vnl_vector<double> >& x_;  // Examples x[i]; held by reference, so the wrapper must outlive this object
00026 
00027   //: c[i] = -1 or +1, indicating class of x[i] (held by reference, not copied)
00028   const vnl_vector<double> & c_;
00029 
00030   //: Min probability, minp in the class formula (avoids log(zero))
00031   double min_p_;
00032 
00033   //: Optional regularising function, r(v) in the class formula (may be 0; not deleted here - no destructor is declared)
00034   vnl_cost_function *reg_fn_;
00035 public:
00036   clsfy_logit_loss_function(mbl_data_wrapper<vnl_vector<double> >& x,  // examples x[i]; presumably bound to x_ - confirm in implementation
00037                             const vnl_vector<double> & c,  // c[i] = -1 or +1, class of x[i]; presumably bound to c_ - confirm in implementation
00038                             double min_p, vnl_cost_function* reg_fn=0);  // min probability; optional regulariser (defaults to 0)
00039 
00040   //:  The main function: Compute the loss f(v) for parameter vector v = (b w')
00041   virtual double f(vnl_vector<double> const& v);
00042 
00043   //:  Calculate the gradient of f at parameter vector v, returned through gradient.
00044   virtual void gradf(vnl_vector<double> const& v, 
00045                      vnl_vector<double>& gradient);
00046 
00047   //: Compute f(v) and its gradient (each only if a non-zero pointer is supplied for it)
00048   virtual void compute(vnl_vector<double> const& v,
00049                        double *f, vnl_vector<double>* gradient);
00050 
00051 };
00052 
00053 //: Simple quadratic term used to regularise functions
00054 //  For vector v' = (b w') (ie b=v[0], w=(v[1]...v[n])), computes
00055 //  f(v) = alpha*|w|^2   (ie ignores first element, which is bias of linear classifier)
00056 class clsfy_quad_regulariser : public vnl_cost_function
00057 {
00058 private:
00059   //: Scaling factor (alpha in f(v) = alpha*|w|^2)
00060   double alpha_;
00061 public:
00062   clsfy_quad_regulariser(double alpha=1e-6);  // alpha: scaling factor. NOTE(review): not explicit, so a double converts implicitly to this type
00063 
00064   //:  The main function: Compute f(v) = alpha*|w|^2 for v = (b w')
00065   virtual double f(vnl_vector<double> const& v);
00066 
00067   //:  Calculate the gradient of f at parameter vector v (per the formula above: 0 for b, 2*alpha*w for w - TODO confirm in implementation).
00068   virtual void gradf(vnl_vector<double> const& v, 
00069                      vnl_vector<double>& gradient);
00070 };
00071 
00072 
00073 #endif // clsfy_logit_loss_function_h_