// This is mul/clsfy/clsfy_logit_loss_function.h
#ifndef clsfy_logit_loss_function_h_
#define clsfy_logit_loss_function_h_
//:
// \file
// \brief Loss function for logit of linear classifier
// \author TFC

#include <vnl/vnl_cost_function.h>
#include <mbl/mbl_data_wrapper.h>

//: Loss function for logit of linear classifier.
//  For vector v' = (b w') (ie b=v[0], w=(v[1]...v[n])), computes
//    f(v) = r(v) - (1/n_eg)sum log[(1-minp)logit(c_i * (b+w.x_i)) + minp]
//
//  This is the regulariser r(v) minus the mean log probability of correct
//  classification, and should be minimised to train the classifier.
//
//  Note: The regulariser only matters when perfect classification is
//  possible; in that case scaling up v always decreases f(v), so without
//  a regulariser there is no finite minimum.
//  A plausible choice of regulariser is clsfy_quad_regulariser (below).
class clsfy_logit_loss_function : public vnl_cost_function
{
 private:
  //: Training examples x[i]
  mbl_data_wrapper<vnl_vector<double> >& x_;

  //: c[i] = -1 or +1, indicating class of x[i]
  const vnl_vector<double>& c_;

  //: Min probability (avoids log(zero))
  double min_p_;

  //: Optional regularising function
  vnl_cost_function* reg_fn_;

 public:
  clsfy_logit_loss_function(mbl_data_wrapper<vnl_vector<double> >& x,
                            const vnl_vector<double>& c,
                            double min_p, vnl_cost_function* reg_fn=0);

  //: The main function: compute f(v)
  virtual double f(vnl_vector<double> const& v);

  //: Calculate the gradient of f at parameter vector v
  virtual void gradf(vnl_vector<double> const& v,
                     vnl_vector<double>& gradient);

  //: Compute f(v) and its gradient (if non-zero pointers supplied)
  virtual void compute(vnl_vector<double> const& v,
                       double* f, vnl_vector<double>* gradient);
};

//: Simple quadratic term used to regularise functions
//  For vector v' = (b w') (ie b=v[0], w=(v[1]...v[n])), computes
//    f(v) = alpha*|w|^2
//  (ie it ignores the first element, which is the bias of the linear classifier)
class clsfy_quad_regulariser : public vnl_cost_function
{
 private:
  //: Scaling factor
  double alpha_;

 public:
  clsfy_quad_regulariser(double alpha=1e-6);

  //: The main function: compute f(v)
  virtual double f(vnl_vector<double> const& v);

  //: Calculate the gradient of f at parameter vector v
  virtual void gradf(vnl_vector<double> const& v,
                     vnl_vector<double>& gradient);
};

#endif // clsfy_logit_loss_function_h_
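
// Example (illustrative sketch, not a verbatim usage from this library):
// training a linear classifier by minimising clsfy_logit_loss_function with
// vnl_lbfgs. mbl_data_array_wrap is assumed as the concrete data wrapper, and
// the function name and the min_p (0.001) and alpha (1e-6) values below are
// illustrative choices, not library defaults.
//
//   #include <vector>
//   #include <vnl/vnl_vector.h>
//   #include <vnl/algo/vnl_lbfgs.h>
//   #include <mbl/mbl_data_array_wrap.h>
//   #include <clsfy/clsfy_logit_loss_function.h>
//
//   // Returns v = (b, w): bias followed by the weights of the trained classifier.
//   // samples[i] is a feature vector; labels[i] is -1 or +1.
//   inline vnl_vector<double> example_train_linear_classifier(
//       const std::vector<vnl_vector<double> >& samples,
//       const vnl_vector<double>& labels)
//   {
//     mbl_data_array_wrap<vnl_vector<double> > data(samples);   // wrap the examples
//     clsfy_quad_regulariser reg(1e-6);                         // optional regulariser
//     clsfy_logit_loss_function loss(data, labels, 0.001, &reg);
//
//     vnl_vector<double> v(samples[0].size() + 1, 0.0);         // (b, w), initialised to zero
//     vnl_lbfgs optimiser(loss);
//     optimiser.minimize(v);                                    // minimise f(v) over v
//     return v;
//   }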
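
// Note on the gradient (a sketch, assuming logit(z) = 1/(1+exp(-z)), the
// logistic function, as the class name suggests): writing
//   z_i = c_i*(b + w.x_i),  p_i = logit(z_i),  u_i = (1, x_i)'
// and using d/dz logit(z) = logit(z)*(1 - logit(z)), the gradient computed
// by gradf() should take the form
//   grad f(v) = r'(v) - (1/n_eg)sum (1-minp)*p_i*(1-p_i)*c_i*u_i / [(1-minp)*p_i + minp]
// Similarly, for clsfy_quad_regulariser with f(v) = alpha*|w|^2, the gradient
// is (0, 2*alpha*w), ie zero in the bias element.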