contrib/mul/mfpf/mfpf_log_lin_class_cost.cxx
Go to the documentation of this file.
00001 #include "mfpf_log_lin_class_cost.h"
00002 //:
00003 // \file
00004 // \brief Computes log prob based on output of a linear classifier
00005 // \author Tim Cootes
00006 
00007 #include <vsl/vsl_binary_loader.h>
00008 #include <vcl_cassert.h>
00009 
00010 #include <vnl/io/vnl_io_vector.h>
00011 
00012 //=======================================================================
00013 // Dflt ctor
00014 //=======================================================================
00015 
00016 mfpf_log_lin_class_cost::mfpf_log_lin_class_cost()
00017 {
00018 }
00019 
00020 //=======================================================================
00021 // Destructor
00022 //=======================================================================
00023 
00024 mfpf_log_lin_class_cost::~mfpf_log_lin_class_cost()
00025 {
00026 }
00027 
00028 //: Define weights, bias and minp
00029 void mfpf_log_lin_class_cost::set(const vnl_vector<double>& wts, 
00030                                   double bias, double min_p)
00031 {
00032   wts_ =wts;
00033   bias_=bias;
00034   min_p_=min_p;
00035 }
00036 
00037 //: Returns -1*log(minp + (1-minp)/(1+exp(-(x.wts-bias)))
00038 double mfpf_log_lin_class_cost::evaluate(const vnl_vector<double>& x)
00039 {
00040   double z = bias_-dot_product(wts_,x);
00041   return -1*vcl_log(min_p_+(1-min_p_)/(1+vcl_exp(z)));
00042 }
00043 
00044 //: Return the weights
00045 void mfpf_log_lin_class_cost::get_average(vnl_vector<double>& v) const
00046 {
00047   v=wts_;
00048 }
00049 
00050 
00051 //=======================================================================
00052 // Method: version_no
00053 //=======================================================================
00054 
00055 short mfpf_log_lin_class_cost::version_no() const
00056 {
00057   return 1;
00058 }
00059 
00060 //=======================================================================
00061 // Method: is_a
00062 //=======================================================================
00063 
00064 vcl_string mfpf_log_lin_class_cost::is_a() const
00065 {
00066   return vcl_string("mfpf_log_lin_class_cost");
00067 }
00068 
00069 //: Create a copy on the heap and return base class pointer
00070 mfpf_vec_cost* mfpf_log_lin_class_cost::clone() const
00071 {
00072   return new mfpf_log_lin_class_cost(*this);
00073 }
00074 
00075 //=======================================================================
00076 // Method: print
00077 //=======================================================================
00078 
00079 void mfpf_log_lin_class_cost::print_summary(vcl_ostream& os) const
00080 {
00081   os<<"Size: "<<wts_.size();
00082 }
00083 
00084 //=======================================================================
00085 // Method: save
00086 //=======================================================================
00087 
00088 void mfpf_log_lin_class_cost::b_write(vsl_b_ostream& bfs) const
00089 {
00090   vsl_b_write(bfs,version_no());
00091   vsl_b_write(bfs,wts_);
00092   vsl_b_write(bfs,bias_);
00093   vsl_b_write(bfs,min_p_);
00094 }
00095 
00096 //=======================================================================
00097 // Method: load
00098 //=======================================================================
00099 
00100 void mfpf_log_lin_class_cost::b_read(vsl_b_istream& bfs)
00101 {
00102   if (!bfs) return;
00103   short version;
00104   vsl_b_read(bfs,version);
00105   switch (version)
00106   {
00107     case (1):
00108       vsl_b_read(bfs,wts_);
00109       vsl_b_read(bfs,bias_);
00110       vsl_b_read(bfs,min_p_);
00111       break;
00112     default:
00113       vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&)\n"
00114                << "           Unknown version number "<< version << vcl_endl;
00115       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00116       return;
00117   }
00118 }