contrib/mul/mbl/mbl_rvm_regression_builder.cxx
Go to the documentation of this file.
00001 // This is mul/mbl/mbl_rvm_regression_builder.cxx
00002 #include "mbl_rvm_regression_builder.h"
00003 //:
00004 // \file
00005 // \brief Object to train Relevance Vector Machines for regression
00006 // \author Tim Cootes
00007 
00008 #include <vcl_cmath.h>
00009 #include <vcl_algorithm.h>
00010 #include <vnl/algo/vnl_svd.h>
00011 #include <mbl/mbl_matxvec.h>
00012 #include <mbl/mbl_matrix_products.h>
00013 #include <vcl_iostream.h>
00014 #include <vcl_cassert.h>
00015 
00016 //=======================================================================
// Note on indexing
// index[i] gives the index (in the range 0..n-1) of the i-th selected vector.
// The offset weight is always included, and is stored first in the weight
// vector, so there are index.size()+1 weights in total.
// The alphas therefore apply only to the n vector weights, not to the offset.
00021 //=======================================================================
00022 // Dflt ctor
00023 //=======================================================================
00024 
00025 mbl_rvm_regression_builder::mbl_rvm_regression_builder()
00026 {
00027 }
00028 
00029 //=======================================================================
00030 // Destructor
00031 //=======================================================================
00032 
00033 mbl_rvm_regression_builder::~mbl_rvm_regression_builder()
00034 {
00035 }
00036 
00037 //: Compute design matrix F from subset of elements in kernel matrix
00038 void mbl_rvm_regression_builder::design_matrix(const vnl_matrix<double>& K,
00039                                                const vcl_vector<int>& index,
00040                                                vnl_matrix<double>& F)
00041 {
00042   unsigned n=index.size();
00043   unsigned ns=K.rows();
00044   F.set_size(ns,n+1);
00045   for (unsigned i=0;i<ns;++i)
00046   {
00047     F(i,0)=1.0;
00048     for (unsigned j=0;j<n;++j)
00049     {
00050       F(i,j+1)=K(i,index[j]);
00051     }
00052   }
00053 }
00054 
00055 //: Train RVM given a set of vectors and set of target values
00056 // Uses gaussian kernel function with variance var
00057 // \param data[i] training vectors
00058 // \param targets[i] gives value at vector i
00059 // \param index returns indices of selected vectors
00060 // \param weights returns weights for selected vectors
00061 // \param error_var returns variance term for gaussian kernel
00062 void mbl_rvm_regression_builder::gauss_build(
00063              mbl_data_wrapper<vnl_vector<double> >& data,
00064              double var, const vnl_vector<double>& targets,
00065              vcl_vector<int>& index,
00066              vnl_vector<double>& weights,
00067              double &error_var)
00068 {
00069   assert(data.size()==targets.size());
00070   unsigned n = data.size();
00071   vnl_matrix<double> K(n,n);
00072   double k = -1.0/2*var;
00073   // Construct kernel matrix
00074   for (unsigned i=1;i<n;++i)
00075   {
00076     data.set_index(i);
00077     vnl_vector<double> vi = data.current();
00078     for (unsigned j=0;j<i;++j)
00079     {
00080       data.set_index(j);
00081       double d = vcl_exp(k*vnl_vector_ssd(vi,data.current()));
00082       K(i,j)=d; K(j,i)=d;
00083     }
00084   }
00085   for (unsigned i=0;i<n;++i) K(i,i)=1.0;
00086 
00087   build(K,targets,index,weights,error_var);
00088 }
00089 
00090 //: Perform one iteration of optimisation
00091 bool mbl_rvm_regression_builder::update_step(const vnl_matrix<double>& F,
00092                                              const vnl_vector<double>& targets,
00093                                              const vcl_vector<int>& index0,
00094                                              const vcl_vector<double>& alpha0,
00095                                              double error_var0,
00096                                              vcl_vector<int>& index,
00097                                              vcl_vector<double>& alpha,
00098                                              double &error_var)
00099 {
00100   unsigned n0 = alpha0.size();
00101   assert(F.rows()==targets.size());
00102   assert(F.cols()==n0+1);
00103   vnl_matrix<double> K_inv;
00104   mbl_matrix_product_at_b(K_inv,F,F);  // K_inv=F'F
00105   K_inv/=error_var0;
00106   for (unsigned i=0;i<n0;++i) K_inv(i+1,i+1)+=alpha0[i];
00107   // K_inv = F'F/var + diag(alpha0)
00108 
00109   vnl_svd<double> svd(K_inv);
00110   S_ = svd.inverse();
00111 
00112   vnl_vector<double> t2(n0+1);
00113   mbl_matxvec_prod_vm(targets,F,t2);  // t2=F'targets  (n+1)
00114   mbl_matxvec_prod_mv(S_,t2,mean_wts_);     // mean=S*t2 (n+1)
00115   mean_wts_/=error_var0;
00116 
00117 #if 0
00118   // ---------------------
00119   // Estimate p(t|alpha,var)
00120   vnl_vector<double> a_inv(n0+1);
00121   a_inv[0]=0.0;
00122   for (unsigned i=0;i<n0;++i) a_inv[i+1]=1.0/alpha0[i];
00123   vnl_matrix<double> FAF;
00124   mbl_matrix_product_adb(FAF,F,a_inv,F.transpose());
00125   for (unsigned i=0;i<FAF.rows();++i) FAF(i,i)+=1.0/error_var0;
00126   vnl_svd<double> FAFsvd(FAF);
00127   vnl_matrix<double> FAFinv=FAFsvd.inverse();
00128   vnl_vector<double> Xt=FAFinv*targets;
00129   double M = dot_product(Xt,targets);
00130   double det=FAFsvd.determinant_magnitude();
00131   vcl_cout<<"M="<<M<<"  -log(p)="<<M+vcl_log(det)<<vcl_endl;
00132   // ---------------------
00133 #endif // 0
00134 
00135   // Compute new alphas and eliminate very large values
00136   alpha.resize(0);
00137   index.resize(0);
00138   double sum=0.0;
00139   double change=0.0;
00140   for (unsigned i=0;i<n0;++i)
00141   {
00142     double a=vcl_max(0.0,1.0-alpha0[i]*S_(i+1,i+1));
00143     sum+=a;
00144     if (vcl_fabs(mean_wts_[i+1])<1e-4) continue;
00145     double mi2 = mean_wts_[i+1]*mean_wts_[i+1];
00146     a/=mi2;
00147 
00148     if (a>1e8) continue;
00149 
00150     alpha.push_back(a);
00151     index.push_back(index0[i]);
00152     change+=vcl_fabs(a-alpha0[i]);
00153   }
00154   // Update estimate of error_var
00155   vnl_vector<double> Fm;
00156   mbl_matxvec_prod_mv(F,mean_wts_,Fm);     // Fm=F*mean
00157   double sum_sqr_error=vnl_vector_ssd(targets,Fm);
00158   error_var = sum_sqr_error/(targets.size()-sum);
00159 // vcl_cout<<"Sum sqr error = "<<sum_sqr_error<<vcl_endl;
00160   change+=vcl_fabs(error_var-error_var0);
00161 
00162   // Decide if optimisation completed
00163   if (alpha.size()!=alpha0.size()) return true;
00164   return change/n0 > 0.01;
00165 }
00166 
00167 //: Train RVM given a distance matrix and set of target values
00168 // \param kernel_matrix (i,j) element gives kernel function between i and j training vectors
00169 // \param targets[i] gives value at vector i
00170 // \param index returns indices of selected vectors
00171 // \param weights returns weights for selected vectors
00172 // \param error_var returns variance term for gaussian kernel
00173 void mbl_rvm_regression_builder::build(
00174              const vnl_matrix<double>& kernel_matrix,
00175              const vnl_vector<double>& targets,
00176              vcl_vector<int>& index,
00177              vnl_vector<double>& weights,
00178              double &error_var)
00179 {
00180   assert(kernel_matrix.rows()==targets.size());
00181   assert(kernel_matrix.cols()<=targets.size());
00182   unsigned n0=kernel_matrix.cols();
00183 
00184   // Initialise to use all n0 samples with equal weights
00185   index.resize(n0);
00186   vcl_vector<double> alpha(n0),new_alpha;
00187   vcl_vector<int> new_index;
00188   for (unsigned i=0;i<n0;++i)  { index[i]=i; alpha[i]=1e-4; }
00189   error_var = 0.01;
00190   double new_error_var;
00191 
00192   vnl_matrix<double> F;
00193   design_matrix(kernel_matrix,index,F);
00194   int max_its=500;
00195   int n_its=0;
00196   while (update_step(F,targets,index,alpha,error_var,
00197                      new_index,new_alpha,new_error_var)  && n_its<max_its)
00198   {
00199     index    = new_index;
00200     alpha    = new_alpha;
00201     error_var= new_error_var;
00202     design_matrix(kernel_matrix,index,F);
00203     n_its++;
00204   }
00205 
00206   if (n_its>=max_its)
00207     vcl_cerr<<"mbl_rvm_regression_builder::build() Too many iterations. Convergence failure.\n";
00208 
00209   weights=mean_wts_;
00210 }