contrib/mul/mbl/mbl_rbf_network.cxx
Go to the documentation of this file.
00001 // This is mul/mbl/mbl_rbf_network.cxx
00002 #include "mbl_rbf_network.h"
00003 //:
00004 // \file
00005 // \brief A class to perform some of the functions of a Radial Basis Function Network.
00006 // \author Tim Cootes
00007 //
//  Given a set of n training vectors, x_i (i=0..n-1), a set of internal weights is computed.
//  Given a new vector, x, a vector of weights, w, is computed such that
//  if x = x_i then w(i) = 1 and w(j) = 0 for j != i (zero-based indexing).
//  The weights sum to unity when sum-to-one normalisation is enabled (the default).
00012 //  If x is not equal to any training vector, the vector of weights varies
00013 //  smoothly.  This is useful for interpolation purposes.
00014 //  It can also be used to define non-linear transformations between
00015 //  vector spaces.  If Y is a matrix of n columns, each corresponding to
00016 //  a vector in a new space which corresponds to one of the original
00017 //  training vectors x_i, then a vector x can be mapped to Yw in the
00018 //  new space.  (Note: y-space does not have to have the same dimension
00019 //  as x space). This class is equivalent to
00020 //  the basis of thin-plate spline warping.
00021 //
00022 //  I'm not sure if this is exactly an RBF network in the original
00023 //  definition. I'll check one day.
00024 
00025 #include <vcl_cstdlib.h>
00026 #include <vcl_cassert.h>
00027 #include <vsl/vsl_indent.h>
00028 #include <mbl/mbl_stats_1d.h>
00029 #include <vnl/algo/vnl_svd.h>
00030 #include <mbl/mbl_matxvec.h>
00031 #include <vnl/io/vnl_io_vector.h>
00032 #include <vsl/vsl_vector_io.h>
00033 
00034 //=======================================================================
00035 // Dflt ctor
00036 //=======================================================================
00037 
00038 mbl_rbf_network::mbl_rbf_network()
00039 {
00040   sum_to_one_ = true;
00041 }
00042 
00043 //: Build weights given examples x.
00044 //  s gives the scaling to use in r2 * vcl_log(r2) r2 = distSqr/(s*s)
00045 //  If s<=0 then a suitable s is estimated from the data
00046 void mbl_rbf_network::build(const vcl_vector<vnl_vector<double> >& x, double s)
00047 {
00048   int n = x.size();
00049   build(&(x.front()),n,s);
00050 }
00051 
00052 //: Build weights given n examples x[0] to x[n-1].
00053 //  s gives the scaling to use in r2 * vcl_log(r2) r2 = distSqr/(s*s)
00054 //  If s<=0 then a suitable s is estimated from the data
00055 void mbl_rbf_network::build(const vnl_vector<double>* x, int n, double s)
00056 {
00057   assert (n>0);
00058   // Copy training examples
00059   if (x_.size()!=(unsigned)n) x_.resize(n);
00060   for (int i=0;i<n;++i)
00061     x_[i] = x[i];
00062 
00063   // Compute distances
00064   vnl_matrix<double> D(n,n);
00065   double **D_data = D.data_array();
00066 
00067   mbl_stats_1d d2_stats;
00068 
00069   for (int i=0;i<n;++i)
00070     D(i,i)=0.0;
00071 
00072   for (int i=0;i<n-1;++i)
00073     for (int j=i+1;j<n;++j)
00074     {
00075       double d2 = distSqr(x_[i],x_[j]);
00076       D_data[i][j] = d2;
00077       D_data[j][i] = d2;
00078       d2_stats.obs(d2);
00079     }
00080 
00081   if (s<=0)
00082   {
00083     s2_ = d2_stats.min();
00084   }
00085   else
00086     s2_ = s*s;
00087 
00088   // Apply rbf() to elements of D
00089   for (int i=0;i<n;++i)
00090     for (int j=0;j<n;++j)
00091       D_data[i][j] = rbf(D_data[i][j]/s2_);
00092 
00093   // W_ is the inverse of D
00094   vnl_svd<double> svd(D);
00095   W_ = svd.inverse();
00096 }
00097 
00098 double mbl_rbf_network::distSqr(const vnl_vector<double>& x, const vnl_vector<double>& y) const
00099 {
00100   unsigned int n = x.size();
00101   if (y.size()!=n)
00102   {
00103     vcl_cerr<<"mbl_rbf_network::distSqr() x and y different sizes.\n";
00104     vcl_abort();
00105   }
00106 
00107   const double *x_data = x.begin();
00108   const double *y_data = y.begin();
00109   double sum = 0.0;
00110   for (unsigned int i=0;i<n;++i)
00111   {
00112     double dx = x_data[i]-y_data[i];
00113     sum += dx*dx;
00114   }
00115 
00116   return sum;
00117 }
00118 
00119 //: Set flag.  If false, calcWts returns raw weights
00120 void mbl_rbf_network::setSumToOne(bool flag)
00121 {
00122   sum_to_one_ = flag;
00123 }
00124 
00125 
00126 //: Compute weights for given new_x.
00127 //  If new_x = x()(i) then w(i+1)==1, w(j!=i+1)==0
00128 //  Otherwise w varies smoothly depending on distance
00129 //  of new_x from x()'s
00130 //  If sumToOne() then elements of w will sum to 1.0
00131 //  otherwise they will sum to <=1.0, decreasing as new_x
00132 //  moves away from the training examples x().
00133 void mbl_rbf_network::calcWts(vnl_vector<double>& w, const vnl_vector<double>& new_x)
00134 {
00135   unsigned int n = x_.size();
00136   if (w.size()!=n) w.set_size(n);
00137   if (v_.size()!=n) v_.set_size(n);
00138 
00139   double* v_data = &v_[0];
00140   const vnl_vector<double>* x_data = &x_[0];
00141 
00142   if (n==1)
00143   {
00144     w(0)=1.0;
00145     return;
00146   }
00147 
00148   if (n==2)
00149   {
00150     // Use linear interpolation based on distance.
00151     double d0 = vcl_sqrt(distSqr(new_x,x_data[0]));
00152     double d1 = vcl_sqrt(distSqr(new_x,x_data[1]));
00153     w(0) = d1/(d0+d1);
00154     w(1) = 1.0 - w(0);
00155     return;
00156   }
00157 
00158   for (unsigned int i=0;i<n;++i)
00159   {
00160     v_data[i] = rbf(new_x,x_data[i]);
00161   }
00162 
00163   mbl_matxvec_prod_mv(W_,v_,w);
00164 
00165   if (sum_to_one_)
00166   {
00167     double sum = w.sum();
00168     if (sum!=0) w/=sum;
00169   }
00170 }
00171 
00172 //=======================================================================
00173 // Method: version_no
00174 //=======================================================================
00175 
00176 short mbl_rbf_network::version_no() const
00177 {
00178   return 1;
00179 }
00180 
00181 //=======================================================================
00182 // Method: is_a
00183 //=======================================================================
00184 
00185 vcl_string mbl_rbf_network::is_a() const
00186 {
00187   return vcl_string("mbl_rbf_network");
00188 }
00189 
00190 //=======================================================================
00191 // Method: is_class
00192 //=======================================================================
00193 
00194 bool mbl_rbf_network::is_class(vcl_string const& s) const
00195 {
00196   return s==is_a();
00197 }
00198 
00199 //=======================================================================
00200 // Method: print
00201 //=======================================================================
00202 
00203 // required if data is present in this class
00204 void mbl_rbf_network::print_summary(vcl_ostream& os) const
00205 {
00206   os << "Built with "<<x_.size()<<" examples.";
00207   //  os << x_ << '\n' << W_ << '\n' << s2_<< '\n';
00208 }
00209 
00210 //=======================================================================
00211 // Method: save
00212 //=======================================================================
00213 
00214 // required if data is present in this class
00215 void mbl_rbf_network::b_write(vsl_b_ostream& bfs) const
00216 {
00217   vsl_b_write(bfs,version_no());
00218   vsl_b_write(bfs,x_);
00219   vsl_b_write(bfs,W_);
00220   vsl_b_write(bfs,s2_);
00221 
00222   if (sum_to_one_)
00223     vsl_b_write(bfs,short(1));
00224   else
00225     vsl_b_write(bfs,short(0));
00226 }
00227 
00228 //=======================================================================
00229 // Method: load
00230 //=======================================================================
00231 
00232 // required if data is present in this class
00233 void mbl_rbf_network::b_read(vsl_b_istream& bfs)
00234 {
00235   if (!bfs) return;
00236 
00237   short version;
00238   short flag;
00239   vsl_b_read(bfs,version);
00240   switch (version)
00241   {
00242   case (1):
00243     vsl_b_read(bfs,x_);
00244     vsl_b_read(bfs,W_);
00245     vsl_b_read(bfs,s2_);
00246     vsl_b_read(bfs,flag);  sum_to_one_ = (flag!=0);
00247     break;
00248   default:
00249     vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, mbl_rbf_network &)\n"
00250              << "           Unknown version number "<< version << vcl_endl;
00251     bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00252     return;
00253   }
00254 }
00255 
00256 
00257 //=======================================================================
00258 // Associated function: operator<<
00259 //=======================================================================
00260 
00261 void vsl_b_write(vsl_b_ostream& bfs, const mbl_rbf_network& b)
00262 {
00263   b.b_write(bfs);
00264 }
00265 
00266 //=======================================================================
00267 // Associated function: operator>>
00268 //=======================================================================
00269 
00270 void vsl_b_read(vsl_b_istream& bfs, mbl_rbf_network& b)
00271 {
00272   b.b_read(bfs);
00273 }
00274 
00275 //=======================================================================
00276 // Associated function: operator<<
00277 //=======================================================================
00278 
00279 vcl_ostream& operator<<(vcl_ostream& os,const mbl_rbf_network& b)
00280 {
00281   os << b.is_a() << ": ";
00282   vsl_indent_inc(os);
00283   b.print_summary(os);
00284   vsl_indent_dec(os);
00285   return os;
00286 }
00287