// This is mul/mbl/mbl_rvm_regression_builder.h
#ifndef mbl_rvm_regression_builder_h_
#define mbl_rvm_regression_builder_h_
//:
// \file
// \brief Object to train Relevance Vector Machines for regression
// \author Tim Cootes

#include <vcl_vector.h>
#include <vnl/vnl_vector.h>
#include <vnl/vnl_matrix.h>
#include <mbl/mbl_data_wrapper.h>

//=======================================================================
//: Object to train Relevance Vector Machines for regression.
// Trains Relevance Vector Machines (see papers by Michael Tipping)
// for regression.
class mbl_rvm_regression_builder
{
 private:
  //: Record of mean weights
  vnl_vector<double> mean_wts_;

  //: Record of covariance in weights
  vnl_matrix<double> S_;

  //: Compute design matrix F from a subset of elements in the kernel matrix
  // Uses a Gaussian distance expression with variance var
  void design_matrix(const vnl_matrix<double>& kernel_matrix,
                     const vcl_vector<int>& index,
                     vnl_matrix<double>& F);

  //: Perform one iteration of the optimisation
  bool update_step(const vnl_matrix<double>& design_matrix,
                   const vnl_vector<double>& targets,
                   const vcl_vector<int>& index0,
                   const vcl_vector<double>& alpha0,
                   double sqr_width0,
                   vcl_vector<int>& index,
                   vcl_vector<double>& alpha,
                   double &error_var);

 public:
  //: Default constructor
  mbl_rvm_regression_builder();

  //: Destructor
  virtual ~mbl_rvm_regression_builder();

  //: Train RVM given a set of vectors and a set of target values
  // The resulting RVM has the form f(x)=w[0]+sum w[i+1]K(x,data[index[i]]),
  // where K(x,y)=exp(-|x-y|^2/(2*var)) and index.size() gives
  // the number of selected vectors.
  // Note that on exit weights.size()==index.size()+1;
  // weights[0] is the constant offset, and weights[i+1]
  // corresponds to selected input vector index[i].
  // \param data[i] training vectors
  // \param targets[i] gives the target value for vector i
  // \param index returns indices of selected vectors
  // \param weights returns weights for selected vectors
  // \param error_var returns estimated error variance of the resulting function
  void gauss_build(mbl_data_wrapper<vnl_vector<double> >& data,
                   double var,
                   const vnl_vector<double>& targets,
                   vcl_vector<int>& index,
                   vnl_vector<double>& weights,
                   double &error_var);
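  // A minimal usage sketch for gauss_build(). Illustrative only: the
  // variable names below are hypothetical, and wrapping the samples in
  // mbl_data_array_wrapper is an assumption, not part of this interface.
  // (Requires <mbl/mbl_data_array_wrapper.h> and <vcl_cmath.h>.)
  // \code
  //   // samples[i] are the training vectors, targets[i] the values to fit.
  //   vcl_vector<vnl_vector<double> > samples;   // hypothetical training data
  //   vnl_vector<double> targets;                // targets.size()==samples.size()
  //   double var = 1.0;                          // kernel variance, chosen for the data
  //
  //   mbl_data_array_wrapper<vnl_vector<double> > data(samples);
  //   mbl_rvm_regression_builder builder;
  //   vcl_vector<int> index;
  //   vnl_vector<double> weights;
  //   double error_var;
  //   builder.gauss_build(data, var, targets, index, weights, error_var);
  //
  //   // Evaluate f(x) = w[0] + sum_i w[i+1]*exp(-|x-samples[index[i]]|^2/(2*var))
  //   vnl_vector<double> x;  // a new vector at which to evaluate f
  //   double f = weights[0];
  //   for (unsigned i=0; i<index.size(); ++i)
  //     f += weights[i+1] * vcl_exp(-(x-samples[index[i]]).squared_magnitude()/(2*var));
  // \endcode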
  //: Train RVM given a kernel matrix and a set of target values
  // The resulting RVM has the form f(x)=w[0]+sum w[i+1]K(x,vec[index[i]]),
  // where K(x,y) is the kernel function and index.size() gives
  // the number of selected vectors.
  // Assuming the original data are vec[i], then on input we should
  // have kernel_matrix(i,j)=K(vec[i],vec[j]).
  // Note that on exit weights.size()==index.size()+1;
  // weights[0] is the constant offset, and weights[i+1]
  // corresponds to selected input vector index[i].
  //
  // The algorithm involves inverting an (n+1)x(n+1) matrix,
  // where n is the number of vectors to consider as potential
  // relevance vectors. This does not have to be all the samples.
  // For efficiency, one can provide the kernel matrix as an
  // m x n matrix, where m is the number of samples to test
  // against (targets.size()==m) but only the first n<=m samples
  // are considered as potential relevance vectors.
  // \param kernel_matrix (i,j) element gives the kernel function between training vectors i and j
  // \param targets[i] gives the target value for vector i
  // \param index returns indices of selected vectors
  // \param weights returns weights for selected vectors
  // \param error_var returns estimated error variance of the resulting function
  void build(const vnl_matrix<double>& kernel_matrix,
             const vnl_vector<double>& targets,
             vcl_vector<int>& index,
             vnl_vector<double>& weights,
             double &error_var);
};

#endif // mbl_rvm_regression_builder_h_
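
//=======================================================================
// A minimal usage sketch for build() with a precomputed kernel matrix.
// Illustrative only: the Gaussian kernel, the variable names and the
// free function below are assumptions, not part of this header.
// \code
//   #include <vcl_cmath.h>
//   #include <mbl/mbl_rvm_regression_builder.h>
//
//   // Train on vec[0..m-1] with targets[0..m-1], treating every sample
//   // as a potential relevance vector (an m x m kernel matrix).
//   void example_train(const vcl_vector<vnl_vector<double> >& vec,
//                      const vnl_vector<double>& targets,
//                      double var,
//                      vcl_vector<int>& index,
//                      vnl_vector<double>& weights,
//                      double& error_var)
//   {
//     unsigned m = vec.size();
//     vnl_matrix<double> kernel_matrix(m,m);
//     for (unsigned i=0;i<m;++i)
//       for (unsigned j=0;j<m;++j)
//         kernel_matrix(i,j) =
//           vcl_exp(-(vec[i]-vec[j]).squared_magnitude()/(2*var));
//
//     mbl_rvm_regression_builder builder;
//     builder.build(kernel_matrix, targets, index, weights, error_var);
//     // The trained function is
//     // f(x) = weights[0] + sum_i weights[i+1]*K(x,vec[index[i]])
//   }
// \endcode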