00001 #ifndef mbl_rbf_network_h_ 00002 #define mbl_rbf_network_h_ 00003 //: 00004 // \file 00005 // \brief A class to perform some of the functions of a Radial Basis Function Network 00006 // \author tfc 00007 // wondrous VXL conversion started by gvw, errors corrected by ... 00008 00009 #include <vsl/vsl_binary_io.h> 00010 #include <vnl/io/vnl_io_vector.h> 00011 #include <vnl/io/vnl_io_matrix.h> 00012 #include <vnl/vnl_vector.h> 00013 #include <vcl_string.h> 00014 #include <vcl_vector.h> 00015 #include <vcl_cmath.h> 00016 #include <vcl_iosfwd.h> 00017 00018 //: A class to perform some of the functions of a Radial Basis Function Network. 00019 // This is a special case of a mixture model pdf, where the same 00020 // (radially symmetric) pdf kernel is used at each node. 00021 // The nodes are supplied by build(). 00022 // calcWts(w,x) calculates the probabilities that x belongs to each 00023 // node. 00024 // Given a set of n training vectors, x_i (i=0..n-1), a set of internal weights are computed. 00025 // Given a new vector, x, a vector of weights, w, are computed such that 00026 // if x = x_i then w(i+1) = 1, w(j !=i+1) = 0 The sum of the weights 00027 // should always be unity. 00028 // If x is not equal to any training vector, the vector of weights varies 00029 // smoothly. This is useful for interpolation purposes. 00030 // It can also be used to define non-linear transformations between 00031 // vector spaces. If Y is a matrix of n columns, each corresponding to 00032 // a vector in a new space which corresponds to one of the original 00033 // training vectors x_i, then a vector x can be mapped to Yw in the 00034 // new space. (Note: y-space does not have to have the same dimension 00035 // as x space). This class is equivalent to 00036 // the basis of thin-plate spline warping. 00037 // 00038 // I'm not sure if this is exactly an RBF network in the original 00039 // definition. I'll check one day. 00040 class mbl_rbf_network 00041 { 00042 vcl_vector<vnl_vector<double> > x_; 00043 vnl_matrix<double> W_; 00044 double s2_; 00045 00046 bool sum_to_one_; 00047 00048 //: workspace 00049 vnl_vector<double> v_; 00050 00051 double distSqr(const vnl_vector<double>& x, const vnl_vector<double>& y) const; 00052 double rbf(double r2) const 00053 { return r2<=0.0 ? 1.0 : vcl_exp(-r2); } 00054 00055 double rbf(const vnl_vector<double>& x, const vnl_vector<double>& y) 00056 { return rbf(distSqr(x,y)/s2_); } 00057 00058 public: 00059 00060 //: Dflt ctor 00061 mbl_rbf_network(); 00062 00063 //: Build weights given examples x. 00064 // s gives the scaling to use in r2 * vcl_log(r2) r2 = distSqr/(s*s) 00065 // If s<=0 then a suitable s is estimated from the data 00066 void build(const vcl_vector<vnl_vector<double> >& x, double s = -1); 00067 00068 //: Build weights given n examples x[0] to x[n-1]. 00069 // s gives the scaling to use in r2 * vcl_log(r2) r2 = distSqr/(s*s) 00070 // If s<=0 then a suitable s is estimated from the data 00071 void build(const vnl_vector<double>* x, int n, double s = -1); 00072 00073 //: If true, then the returned weights sum to 1.0 00074 bool sumToOne() const { return sum_to_one_; } 00075 00076 //: Set flag. If false, calcWts returns raw weights 00077 void setSumToOne(bool flag); 00078 00079 //: Array of training vectors x, supplied in last build() 00080 const vcl_vector<vnl_vector<double> >& x() const { return x_;} 00081 00082 //: Compute weights for given new_x. 00083 // If new_x = x()(i) then w(i+1)==1, w(j!=i+1)==0 00084 // Otherwise w varies smoothly depending on distance 00085 // of new_x from x()'s 00086 // If sumToOne() then elements of w will sum to 1.0 00087 // otherwise they will sum to <=1.0, decreasing as new_x 00088 // moves away from the training examples x(). 00089 void calcWts(vnl_vector<double>& w, const vnl_vector<double>& new_x); 00090 00091 //: Version number for I/O 00092 short version_no() const; 00093 00094 //: Name of the class 00095 vcl_string is_a() const; 00096 00097 //: True if this is (or is derived from) class named s 00098 bool is_class(vcl_string const& s) const; 00099 00100 //: Print class to os 00101 void print_summary(vcl_ostream& os) const; 00102 00103 //: Save class to binary file stream 00104 void b_write(vsl_b_ostream& bfs) const; 00105 00106 //: Load class from binary file stream 00107 void b_read(vsl_b_istream& bfs); 00108 }; 00109 00110 //: Binary file stream output operator for class reference 00111 void vsl_b_write(vsl_b_ostream& bfs, const mbl_rbf_network& b); 00112 00113 //: Binary file stream input operator for class reference 00114 void vsl_b_read(vsl_b_istream& bfs, mbl_rbf_network& b); 00115 00116 //: Stream output operator for class reference 00117 vcl_ostream& operator<<(vcl_ostream& os,const mbl_rbf_network& b); 00118 00119 #endif //mbl_rbf_network_h_