contrib/mul/mbl/mbl_rbf_network.h
Go to the documentation of this file.
00001 #ifndef mbl_rbf_network_h_
00002 #define mbl_rbf_network_h_
00003 //:
00004 // \file
00005 // \brief A class to perform some of the functions of a Radial Basis Function Network
00006 // \author tfc
00007 //         wondrous VXL conversion started by gvw, errors corrected by ...
00008 
00009 #include <vsl/vsl_binary_io.h>
00010 #include <vnl/io/vnl_io_vector.h>
00011 #include <vnl/io/vnl_io_matrix.h>
00012 #include <vnl/vnl_vector.h>
00013 #include <vcl_string.h>
00014 #include <vcl_vector.h>
00015 #include <vcl_cmath.h>
00016 #include <vcl_iosfwd.h>
00017 
00018 //: A class to perform some of the functions of a Radial Basis Function Network.
00019 //  This is a special case of a mixture model pdf, where the same
00020 //  (radially symmetric) pdf kernel is used at each node.
00021 //  The nodes are supplied by build().
00022 //  calcWts(w,x) calculates the probabilities that x belongs to each
00023 //  node.
00024 //  Given a set of n training vectors, x_i (i=0..n-1), a set of internal weights is computed.
00025 //  Given a new vector, x, a vector of weights, w, is computed such that
00026 //  if x = x_i then w(i+1) = 1, w(j !=i+1) = 0.  The sum of the weights
00027 //  should always be unity.
00028 //  If x is not equal to any training vector, the vector of weights varies
00029 //  smoothly.  This is useful for interpolation purposes.
00030 //  It can also be used to define non-linear transformations between
00031 //  vector spaces.  If Y is a matrix of n columns, each corresponding to
00032 //  a vector in a new space which corresponds to one of the original
00033 //  training vectors x_i, then a vector x can be mapped to Yw in the
00034 //  new space.  (Note: y-space does not have to have the same dimension
00035 //  as x space). This class is equivalent to
00036 //  the basis of thin-plate spline warping.
00037 //
00038 //  I'm not sure if this is exactly an RBF network in the original
00039 //  definition. I'll check one day.
00040 class mbl_rbf_network
00041 {
00042   vcl_vector<vnl_vector<double> > x_;  // training vectors; exposed read-only through x()
00043   vnl_matrix<double> W_;               // internal weight matrix (presumably filled in by build() - confirm in the .cxx)
00044   double s2_;                          // kernel scale: rbf(x,y) divides distSqr(x,y) by this (presumably s*s - confirm)
00045   //: Flag returned by sumToOne(): if true, the weights from calcWts() sum to 1.0
00046   bool sum_to_one_;
00047 
00048   //: Workspace vector (presumably scratch storage for calcWts(), which would explain why that method is non-const - confirm)
00049   vnl_vector<double> v_;
00050   //: Distance measure between x and y (by its name, the squared distance; defined outside this header)
00051   double distSqr(const vnl_vector<double>& x, const vnl_vector<double>& y) const;
00052   double rbf(double r2) const  // kernel on the scaled distance r2: 1.0 when r2<=0, otherwise vcl_exp(-r2)
00053     { return r2<=0.0 ? 1.0 : vcl_exp(-r2); }
00054   //: Kernel for a pair of vectors: rbf(distSqr(x,y)/s2_).  Const, since it only reads s2_ and calls const members.
00055   double rbf(const vnl_vector<double>& x, const vnl_vector<double>& y) const
00056     { return rbf(distSqr(x,y)/s2_); }
00057 
00058  public:
00059 
00060   //: Default constructor
00061   mbl_rbf_network();
00062 
00063   //: Build weights given examples x.
00064   //  s gives the scaling applied before the kernel vcl_exp(-r2), where r2 = distSqr/(s*s)
00065   //  If s<=0 then a suitable s is estimated from the data
00066   void build(const vcl_vector<vnl_vector<double> >& x, double s = -1);
00067 
00068   //: Build weights given n examples x[0] to x[n-1].
00069   //  s gives the scaling applied before the kernel vcl_exp(-r2), where r2 = distSqr/(s*s)
00070   //  If s<=0 then a suitable s is estimated from the data
00071   void build(const vnl_vector<double>* x, int n, double s = -1);
00072 
00073   //: If true, then the weights returned by calcWts() sum to 1.0
00074   bool sumToOne() const { return sum_to_one_; }
00075 
00076   //: Set the sum-to-one flag.  If false, calcWts returns raw weights
00077   void setSumToOne(bool flag);
00078 
00079   //: Array of training vectors x, supplied in last build()
00080   const vcl_vector<vnl_vector<double> >& x() const { return x_;}
00081 
00082   //: Compute weights w for given new_x.
00083   //  If new_x equals one of the training vectors x(), the weight for that vector is 1 and all other weights are 0.
00084   //  (NOTE(review): the original text indexed that weight as w(i+1) for x()(i); check calcWts() for the actual indexing.)
00085   //  Otherwise w varies smoothly depending on the distance of new_x from the x()'s.
00086   //  If sumToOne() then elements of w will sum to 1.0
00087   //  otherwise they will sum to <=1.0, decreasing as new_x
00088   //  moves away from the training examples x().
00089   void calcWts(vnl_vector<double>& w, const vnl_vector<double>& new_x);
00090 
00091   //: Version number for I/O
00092   short version_no() const;
00093 
00094   //: Name of the class
00095   vcl_string is_a() const;
00096 
00097   //: True if this is (or is derived from) class named s
00098   bool is_class(vcl_string const& s) const;
00099 
00100   //: Print a summary of the class to os
00101   void print_summary(vcl_ostream& os) const;
00102 
00103   //: Save class to binary file stream bfs
00104   void b_write(vsl_b_ostream& bfs) const;
00105 
00106   //: Load class from binary file stream bfs
00107   void b_read(vsl_b_istream& bfs);
00108 };
00109 
00110 //: Binary file stream output operator for class reference: writes b to bfs (presumably by calling b.b_write(bfs); defined elsewhere - confirm)
00111 void vsl_b_write(vsl_b_ostream& bfs, const mbl_rbf_network& b);
00112 
00113 //: Binary file stream input operator for class reference: loads b from bfs (presumably by calling b.b_read(bfs); defined elsewhere - confirm)
00114 void vsl_b_read(vsl_b_istream& bfs, mbl_rbf_network& b);
00115 
00116 //: Stream output operator for class reference: sends b to os and returns a vcl_ostream& (presumably os itself, after b.print_summary(os); defined elsewhere - confirm)
00117 vcl_ostream& operator<<(vcl_ostream& os,const mbl_rbf_network& b);
00118 
00119 #endif //mbl_rbf_network_h_