00001
00002 #include "mbl_rvm_regression_builder.h"
00003
00004
00005
00006
00007
00008 #include <vcl_cmath.h>
00009 #include <vcl_algorithm.h>
00010 #include <vnl/algo/vnl_svd.h>
00011 #include <mbl/mbl_matxvec.h>
00012 #include <mbl/mbl_matrix_products.h>
00013 #include <vcl_iostream.h>
00014 #include <vcl_cassert.h>
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
//=======================================================================
// Default constructor: no state to initialise (S_ and mean_wts_ are
// set during build()).
//=======================================================================
mbl_rvm_regression_builder::mbl_rvm_regression_builder()
{
}
00028
00029
00030
00031
00032
//=======================================================================
// Destructor: nothing to release; members clean themselves up.
//=======================================================================
mbl_rvm_regression_builder::~mbl_rvm_regression_builder()
{
}
00036
00037
00038 void mbl_rvm_regression_builder::design_matrix(const vnl_matrix<double>& K,
00039 const vcl_vector<int>& index,
00040 vnl_matrix<double>& F)
00041 {
00042 unsigned n=index.size();
00043 unsigned ns=K.rows();
00044 F.set_size(ns,n+1);
00045 for (unsigned i=0;i<ns;++i)
00046 {
00047 F(i,0)=1.0;
00048 for (unsigned j=0;j<n;++j)
00049 {
00050 F(i,j+1)=K(i,index[j]);
00051 }
00052 }
00053 }
00054
00055
00056
00057
00058
00059
00060
00061
00062 void mbl_rvm_regression_builder::gauss_build(
00063 mbl_data_wrapper<vnl_vector<double> >& data,
00064 double var, const vnl_vector<double>& targets,
00065 vcl_vector<int>& index,
00066 vnl_vector<double>& weights,
00067 double &error_var)
00068 {
00069 assert(data.size()==targets.size());
00070 unsigned n = data.size();
00071 vnl_matrix<double> K(n,n);
00072 double k = -1.0/2*var;
00073
00074 for (unsigned i=1;i<n;++i)
00075 {
00076 data.set_index(i);
00077 vnl_vector<double> vi = data.current();
00078 for (unsigned j=0;j<i;++j)
00079 {
00080 data.set_index(j);
00081 double d = vcl_exp(k*vnl_vector_ssd(vi,data.current()));
00082 K(i,j)=d; K(j,i)=d;
00083 }
00084 }
00085 for (unsigned i=0;i<n;++i) K(i,i)=1.0;
00086
00087 build(K,targets,index,weights,error_var);
00088 }
00089
00090
//: One iteration of the RVM hyper-parameter re-estimation.
//  Given design matrix F, currently selected indices index0, precisions
//  alpha0 and noise variance error_var0, computes the posterior weight
//  covariance (S_) and mean (mean_wts_), then re-estimates the precisions
//  and noise variance, pruning terms whose weight is negligible or whose
//  precision diverges.
// \return true if the estimate changed significantly (caller should iterate)
bool mbl_rvm_regression_builder::update_step(const vnl_matrix<double>& F,
                                             const vnl_vector<double>& targets,
                                             const vcl_vector<int>& index0,
                                             const vcl_vector<double>& alpha0,
                                             double error_var0,
                                             vcl_vector<int>& index,
                                             vcl_vector<double>& alpha,
                                             double &error_var)
{
  unsigned n0 = alpha0.size();
  assert(F.rows()==targets.size());
  assert(F.cols()==n0+1);

  // Posterior weight precision: K_inv = F'F/error_var0 + diag(0, alpha0).
  // Entry (0,0) gets no alpha term -- the bias column is unregularised.
  vnl_matrix<double> K_inv;
  mbl_matrix_product_at_b(K_inv,F,F);
  K_inv/=error_var0;
  for (unsigned i=0;i<n0;++i) K_inv(i+1,i+1)+=alpha0[i];

  // Posterior covariance S_ = K_inv^-1 (SVD inverse for numerical safety)
  vnl_svd<double> svd(K_inv);
  S_ = svd.inverse();

  // Posterior mean: mean_wts_ = S_ * F' * targets / error_var0
  vnl_vector<double> t2(n0+1);
  mbl_matxvec_prod_vm(targets,F,t2);
  mbl_matxvec_prod_mv(S_,t2,mean_wts_);
  mean_wts_/=error_var0;

#if 0
  // Disabled diagnostic: evaluates (something proportional to) the
  // negative log marginal likelihood and prints it.
  vnl_vector<double> a_inv(n0+1);
  a_inv[0]=0.0;
  for (unsigned i=0;i<n0;++i) a_inv[i+1]=1.0/alpha0[i];
  vnl_matrix<double> FAF;
  mbl_matrix_product_adb(FAF,F,a_inv,F.transpose());
  for (unsigned i=0;i<FAF.rows();++i) FAF(i,i)+=1.0/error_var0;
  vnl_svd<double> FAFsvd(FAF);
  vnl_matrix<double> FAFinv=FAFsvd.inverse();
  vnl_vector<double> Xt=FAFinv*targets;
  double M = dot_product(Xt,targets);
  double det=FAFsvd.determinant_magnitude();
  vcl_cout<<"M="<<M<<" -log(p)="<<M+vcl_log(det)<<vcl_endl;
#endif // 0

  // Re-estimate each alpha_i = gamma_i / m_i^2, pruning as we go
  alpha.resize(0);
  index.resize(0);
  double sum=0.0;     // accumulates gamma_i = 1 - alpha_i*S_ii (effective number of parameters)
  double change=0.0;  // total absolute change in hyper-parameters this step
  for (unsigned i=0;i<n0;++i)
  {
    double a=vcl_max(0.0,1.0-alpha0[i]*S_(i+1,i+1));
    sum+=a;
    // Weight essentially zero: drop this term entirely
    if (vcl_fabs(mean_wts_[i+1])<1e-4) continue;
    double mi2 = mean_wts_[i+1]*mean_wts_[i+1];
    a/=mi2;
    // Precision diverging (weight forced to zero): drop this term
    if (a>1e8) continue;

    alpha.push_back(a);
    index.push_back(index0[i]);
    change+=vcl_fabs(a-alpha0[i]);
  }

  // Re-estimate noise variance from the residual sum of squares,
  // normalised by (n_samples - effective number of parameters)
  vnl_vector<double> Fm;
  mbl_matxvec_prod_mv(F,mean_wts_,Fm);
  double sum_sqr_error=vnl_vector_ssd(targets,Fm);
  error_var = sum_sqr_error/(targets.size()-sum);

  change+=vcl_fabs(error_var-error_var0);

  // Keep iterating if any term was pruned, or if the mean absolute
  // parameter change is still above threshold
  if (alpha.size()!=alpha0.size()) return true;
  return change/n0 > 0.01;
}
00166
00167
00168
00169
00170
00171
00172
00173 void mbl_rvm_regression_builder::build(
00174 const vnl_matrix<double>& kernel_matrix,
00175 const vnl_vector<double>& targets,
00176 vcl_vector<int>& index,
00177 vnl_vector<double>& weights,
00178 double &error_var)
00179 {
00180 assert(kernel_matrix.rows()==targets.size());
00181 assert(kernel_matrix.cols()<=targets.size());
00182 unsigned n0=kernel_matrix.cols();
00183
00184
00185 index.resize(n0);
00186 vcl_vector<double> alpha(n0),new_alpha;
00187 vcl_vector<int> new_index;
00188 for (unsigned i=0;i<n0;++i) { index[i]=i; alpha[i]=1e-4; }
00189 error_var = 0.01;
00190 double new_error_var;
00191
00192 vnl_matrix<double> F;
00193 design_matrix(kernel_matrix,index,F);
00194 int max_its=500;
00195 int n_its=0;
00196 while (update_step(F,targets,index,alpha,error_var,
00197 new_index,new_alpha,new_error_var) && n_its<max_its)
00198 {
00199 index = new_index;
00200 alpha = new_alpha;
00201 error_var= new_error_var;
00202 design_matrix(kernel_matrix,index,F);
00203 n_its++;
00204 }
00205
00206 if (n_its>=max_its)
00207 vcl_cerr<<"mbl_rvm_regression_builder::build() Too many iterations. Convergence failure.\n";
00208
00209 weights=mean_wts_;
00210 }