Go to the documentation of this file.00001
00002 #include "mbl_rbf_network.h"
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #include <vcl_cstdlib.h>
00026 #include <vcl_cassert.h>
00027 #include <vsl/vsl_indent.h>
00028 #include <mbl/mbl_stats_1d.h>
00029 #include <vnl/algo/vnl_svd.h>
00030 #include <mbl/mbl_matxvec.h>
00031 #include <vnl/io/vnl_io_vector.h>
00032 #include <vsl/vsl_vector_io.h>
00033
00034
00035
00036
00037
00038 mbl_rbf_network::mbl_rbf_network()
00039 {
00040 sum_to_one_ = true;
00041 }
00042
00043
00044
00045
00046 void mbl_rbf_network::build(const vcl_vector<vnl_vector<double> >& x, double s)
00047 {
00048 int n = x.size();
00049 build(&(x.front()),n,s);
00050 }
00051
00052
00053
00054
00055 void mbl_rbf_network::build(const vnl_vector<double>* x, int n, double s)
00056 {
00057 assert (n>0);
00058
00059 if (x_.size()!=(unsigned)n) x_.resize(n);
00060 for (int i=0;i<n;++i)
00061 x_[i] = x[i];
00062
00063
00064 vnl_matrix<double> D(n,n);
00065 double **D_data = D.data_array();
00066
00067 mbl_stats_1d d2_stats;
00068
00069 for (int i=0;i<n;++i)
00070 D(i,i)=0.0;
00071
00072 for (int i=0;i<n-1;++i)
00073 for (int j=i+1;j<n;++j)
00074 {
00075 double d2 = distSqr(x_[i],x_[j]);
00076 D_data[i][j] = d2;
00077 D_data[j][i] = d2;
00078 d2_stats.obs(d2);
00079 }
00080
00081 if (s<=0)
00082 {
00083 s2_ = d2_stats.min();
00084 }
00085 else
00086 s2_ = s*s;
00087
00088
00089 for (int i=0;i<n;++i)
00090 for (int j=0;j<n;++j)
00091 D_data[i][j] = rbf(D_data[i][j]/s2_);
00092
00093
00094 vnl_svd<double> svd(D);
00095 W_ = svd.inverse();
00096 }
00097
00098 double mbl_rbf_network::distSqr(const vnl_vector<double>& x, const vnl_vector<double>& y) const
00099 {
00100 unsigned int n = x.size();
00101 if (y.size()!=n)
00102 {
00103 vcl_cerr<<"mbl_rbf_network::distSqr() x and y different sizes.\n";
00104 vcl_abort();
00105 }
00106
00107 const double *x_data = x.begin();
00108 const double *y_data = y.begin();
00109 double sum = 0.0;
00110 for (unsigned int i=0;i<n;++i)
00111 {
00112 double dx = x_data[i]-y_data[i];
00113 sum += dx*dx;
00114 }
00115
00116 return sum;
00117 }
00118
00119
00120 void mbl_rbf_network::setSumToOne(bool flag)
00121 {
00122 sum_to_one_ = flag;
00123 }
00124
00125
00126
00127
00128
00129
00130
00131
00132
00133 void mbl_rbf_network::calcWts(vnl_vector<double>& w, const vnl_vector<double>& new_x)
00134 {
00135 unsigned int n = x_.size();
00136 if (w.size()!=n) w.set_size(n);
00137 if (v_.size()!=n) v_.set_size(n);
00138
00139 double* v_data = &v_[0];
00140 const vnl_vector<double>* x_data = &x_[0];
00141
00142 if (n==1)
00143 {
00144 w(0)=1.0;
00145 return;
00146 }
00147
00148 if (n==2)
00149 {
00150
00151 double d0 = vcl_sqrt(distSqr(new_x,x_data[0]));
00152 double d1 = vcl_sqrt(distSqr(new_x,x_data[1]));
00153 w(0) = d1/(d0+d1);
00154 w(1) = 1.0 - w(0);
00155 return;
00156 }
00157
00158 for (unsigned int i=0;i<n;++i)
00159 {
00160 v_data[i] = rbf(new_x,x_data[i]);
00161 }
00162
00163 mbl_matxvec_prod_mv(W_,v_,w);
00164
00165 if (sum_to_one_)
00166 {
00167 double sum = w.sum();
00168 if (sum!=0) w/=sum;
00169 }
00170 }
00171
00172
00173
00174
00175
00176 short mbl_rbf_network::version_no() const
00177 {
00178 return 1;
00179 }
00180
00181
00182
00183
00184
00185 vcl_string mbl_rbf_network::is_a() const
00186 {
00187 return vcl_string("mbl_rbf_network");
00188 }
00189
00190
00191
00192
00193
00194 bool mbl_rbf_network::is_class(vcl_string const& s) const
00195 {
00196 return s==is_a();
00197 }
00198
00199
00200
00201
00202
00203
00204 void mbl_rbf_network::print_summary(vcl_ostream& os) const
00205 {
00206 os << "Built with "<<x_.size()<<" examples.";
00207
00208 }
00209
00210
00211
00212
00213
00214
00215 void mbl_rbf_network::b_write(vsl_b_ostream& bfs) const
00216 {
00217 vsl_b_write(bfs,version_no());
00218 vsl_b_write(bfs,x_);
00219 vsl_b_write(bfs,W_);
00220 vsl_b_write(bfs,s2_);
00221
00222 if (sum_to_one_)
00223 vsl_b_write(bfs,short(1));
00224 else
00225 vsl_b_write(bfs,short(0));
00226 }
00227
00228
00229
00230
00231
00232
00233 void mbl_rbf_network::b_read(vsl_b_istream& bfs)
00234 {
00235 if (!bfs) return;
00236
00237 short version;
00238 short flag;
00239 vsl_b_read(bfs,version);
00240 switch (version)
00241 {
00242 case (1):
00243 vsl_b_read(bfs,x_);
00244 vsl_b_read(bfs,W_);
00245 vsl_b_read(bfs,s2_);
00246 vsl_b_read(bfs,flag); sum_to_one_ = (flag!=0);
00247 break;
00248 default:
00249 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, mbl_rbf_network &)\n"
00250 << " Unknown version number "<< version << vcl_endl;
00251 bfs.is().clear(vcl_ios::badbit);
00252 return;
00253 }
00254 }
00255
00256
00257
00258
00259
00260
00261 void vsl_b_write(vsl_b_ostream& bfs, const mbl_rbf_network& b)
00262 {
00263 b.b_write(bfs);
00264 }
00265
00266
00267
00268
00269
00270 void vsl_b_read(vsl_b_istream& bfs, mbl_rbf_network& b)
00271 {
00272 b.b_read(bfs);
00273 }
00274
00275
00276
00277
00278
00279 vcl_ostream& operator<<(vcl_ostream& os,const mbl_rbf_network& b)
00280 {
00281 os << b.is_a() << ": ";
00282 vsl_indent_inc(os);
00283 b.print_summary(os);
00284 vsl_indent_dec(os);
00285 return os;
00286 }
00287