00001 // This is mul/mbl/mbl_stepwise_regression.h 00002 #ifndef mbl_stepwise_regression_h_ 00003 #define mbl_stepwise_regression_h_ 00004 //: 00005 // \file 00006 // \brief Conduct stepwise regression 00007 // \author Martin Roberts 00008 00009 #include <vcl_set.h> 00010 #include <vnl/vnl_vector.h> 00011 #include <vnl/vnl_matrix.h> 00012 00013 00014 //: Perform the stepwise regression algorithm to determine which subset of variables appear to be significant predictors 00015 class mbl_stepwise_regression 00016 { 00017 //: Forwards or backwards stepwise (add or remove variables) 00018 enum step_mode {eFORWARDS,eBACKWARDS}; 00019 //: The data matrix of x values (predictor variables) 00020 // Each training example is a row, each x-variable dimension corresponds to a column 00021 const vnl_matrix<double>& x_; 00022 //: Vector of dependent y values 00023 const vnl_vector<double>& y_; 00024 //: number of training examples (i.e. number of rows in x_) 00025 unsigned num_examples_; 00026 //: dimensionality - (i.e. number of columns in x_) 00027 unsigned num_vars_; 00028 //: [x,-1]'[x,-1] 00029 vnl_matrix<double> XtX_; 00030 //: [x,-1]'[y,-1] 00031 vnl_vector<double> XtY_; 00032 //: Basis (i.e. all significant variables) 00033 vcl_set<unsigned> basis_ ; 00034 //: All non-basis variables 00035 vcl_set<unsigned> basis_complement_ ; 00036 //: The regression coefficients + constant (final term) 00037 vnl_vector<double> weights_; 00038 00039 //: The residual sum of squares 00040 double rss_; 00041 //: F-ratio significance threshold for adding a variable 00042 double FthreshAdd_; 00043 //: F-ratio significance threshold for removing a variable 00044 // Note must be less than addition threshold or infinite cycling will occur 00045 double FthreshRemove_; 00046 //: forwards or backwards mode 00047 step_mode mode_; 00048 00049 00050 //: Add a new variable and return if significant 00051 // Only added if it makes a significant reduction in RSS, unless forceAdd is set 00052 // Always adds the variable making most difference to RSS 00053 bool add_variable(bool forceAdd=false); 00054 00055 //: Remove a variable that makes no significant difference to RSS 00056 // Will remove the one that causes least change to RSS 00057 bool remove_variable(); 00058 00059 double f_ratio(double rssExtended,double rssBase,unsigned q) 00060 { 00061 double chi2Extra = rssBase - rssExtended ; 00062 double f1 = chi2Extra/double(q); 00063 double f2 = rssExtended/double(num_examples_ - (basis_.size()+q) - 1); 00064 return f1/f2; 00065 } 00066 //: Evaluate F-ratio to test if change in sum of squares was significant 00067 bool test_significance(double rssExtended,double rssBase, double fthresh) 00068 { 00069 return ((f_ratio(rssExtended,rssBase,1) >fthresh) ? true : false); 00070 } 00071 //: Step forward through basis complement, adding in new significant variables and removing those that cease to be significant 00072 void do_forward_stepwise_regression(); 00073 00074 //: Step back and remove all insignificant variables, then try and step forward again 00075 void do_backward_stepwise_regression(); 00076 00077 public: 00078 //: Constructor, note you must supply the data references 00079 // These must remain in scope during algorithm execution 00080 // The data matrix of x values (predictor variables) is arranged as: 00081 // Each training example is a row, each x-variable dimension corresponds to a column 00082 mbl_stepwise_regression(const vnl_matrix<double>& x, 00083 const vnl_vector<double>& y); 00084 00085 //: Run the algorithm as determined by mode 00086 void operator()(); 00087 00088 //: return the basis variables 00089 // I.e. those determined to be significantly correlated with y in stepwise search 00090 const vcl_set<unsigned > basis() const {return basis_;} 00091 00092 //: Set the mode to forwards or backwards 00093 // Note backwards can take a long compute time in a space of high dimension 00094 void set_mode(step_mode mode) {mode_ = mode;} 00095 00096 //: Return the regression coefficients + constant (final term) 00097 const vnl_vector<double >& weights() const {return weights_;} 00098 }; 00099 00100 //:Helper stuff for stepwise regression 00101 namespace mbl_stepwise_regression_helpers 00102 { 00103 //: Do the regression fitting for a given basis instance 00104 class lsfit_this_basis 00105 { 00106 //: The data matrix of x values (predictor variables) 00107 // Each training example is a row, each x-variable dimension corresponds to a column 00108 const vnl_matrix<double>& x_; 00109 //: Vector of dependent y values 00110 const vnl_vector<double>& y_; 00111 //: x'x 00112 const vnl_matrix<double>& XtX_; 00113 //: x'y 00114 const vnl_vector<double>& XtY_; 00115 //: the basis (note ordered by variable index) 00116 vcl_set<unsigned> basis_ ; 00117 //: number of training examples (i.e. number of rows in x_) 00118 unsigned num_examples_; 00119 //: dimensionality - (i.e. number of columns in x_) 00120 unsigned num_vars_; 00121 //: The regression coefficients determine for the significant variables in the basis 00122 // NB These are in the order of the basis variables, e.g. weights_[1] is for the second variable 00123 // Also note the size of weights is one more than the basis, the last term being the constant 00124 vnl_vector<double> weights_; 00125 public: 00126 //: constructor, note supply the data references 00127 lsfit_this_basis(const vnl_matrix<double>& x, 00128 const vnl_vector<double>& y, 00129 const vnl_matrix<double>& XtX, 00130 const vnl_vector<double>& XtY): 00131 x_(x),y_(y),XtX_(XtX),XtY_(XtY) 00132 { 00133 num_examples_ = y.size(); 00134 num_vars_ = x.cols(); 00135 } 00136 //: Set the basis 00137 void set_basis(vcl_set<unsigned>& basis) {basis_ = basis;} 00138 //: return the basis 00139 const vcl_set<unsigned>& basis() const {return basis_;} 00140 00141 //:Try adding variable k to the basis and then fit the extended basis, returning resid sum of squares 00142 // Note the basis is not actually updated, only temporarily for the duration of this call 00143 double add(unsigned k); 00144 00145 //:Try removing variable k from the basis and then fit the extended basis, returning resid sum of squares 00146 // Note the basis is not actually updated, only temporarily for the duration of this call 00147 double remove(unsigned k); 00148 00149 //: Fit the current basis 00150 double operator()(); 00151 00152 //: return the regression coefficients and constant (final term) 00153 const vnl_vector<double >& weights() const {return weights_;} 00154 }; 00155 }; 00156 00157 #endif // mbl_stepwise_regression_h_