contrib/mul/msm/msm_ref_shape_instance.cxx
Go to the documentation of this file.
00001 #include "msm_ref_shape_instance.h"
00002 //:
00003 // \file
00004 // \brief Representation of an instance of a shape model in ref frame.
00005 // \author Tim Cootes
00006 
00007 #include <vsl/vsl_indent.h>
00008 #include <vsl/vsl_binary_io.h>
00009 #include <vnl/io/vnl_io_vector.h>
00010 #include <vnl/algo/vnl_cholesky.h>
00011 #include <vnl/algo/vnl_svd.h>
00012 
00013 #include <vcl_cstdlib.h>  // for vcl_atoi() & vcl_abort()
00014 #include <vcl_iostream.h>
00015 #include <vcl_cassert.h>
00016 
00017 #include <msm/msm_ref_shape_model.h>
00018 #include <msm/msm_no_limiter.h>
00019 #include <mbl/mbl_matxvec.h>
00020 #include <mbl/mbl_matrix_products.h>
00021 
00022 //=======================================================================
00023 // Dflt ctor
00024 //=======================================================================
00025 
00026 msm_ref_shape_instance::msm_ref_shape_instance()
00027   : model_(0),use_prior_(false)
00028 {
00029   param_limiter_ = msm_no_limiter();
00030 }
00031 
00032 //=======================================================================
00033 // Destructor
00034 //=======================================================================
00035 
00036 msm_ref_shape_instance::~msm_ref_shape_instance()
00037 {
00038 }
00039 
00040 //: Set up model (retains pointer to model)
00041 void msm_ref_shape_instance::set_shape_model(const msm_ref_shape_model& model)
00042 {
00043   model_=&model;
00044 
00045   b_.set_size(model.n_modes());
00046   b_.fill(0);
00047 
00048   param_limiter_ = model.param_limiter().clone();
00049 
00050   points_valid_=false;
00051 }
00052 
00053 //: Define limits on parameters (clone taken)
00054 void msm_ref_shape_instance::set_param_limiter(const msm_param_limiter& limiter)
00055 {
00056   param_limiter_ = limiter;
00057 }
00058 
00059 //: When true, use Gaussian prior on params in fit_to_points*
00060 void msm_ref_shape_instance::set_use_prior(bool b)
00061 {
00062   use_prior_ = b;
00063 }
00064 
00065 //: Define parameters
00066 void msm_ref_shape_instance::set_params(const vnl_vector<double>& b)
00067 {
00068   assert(b.size()<=model().n_modes());
00069   b_=b;
00070   points_valid_=false;
00071 }
00072 
00073 //: Set all shape parameters to zero
00074 void msm_ref_shape_instance::set_to_mean()
00075 {
00076   if (b_.size()==0) return;
00077   b_.fill(0.0);
00078   points_valid_=false;
00079 }
00080 
00081 
00082 //: Current shape in model frame (uses lazy evaluation)
00083 const msm_points& msm_ref_shape_instance::points()
00084 {
00085   if (points_valid_) return points_;
00086 
00087   // Need to recalculate points_
00088   if (b_.size()==0)
00089     points_.vector()=model().mean();
00090   else
00091   {
00092     // Use only b_.size() modes.
00093     mbl_matxvec_prod_mv(model().modes(),b_,points_.vector());
00094     points_.vector() += model().mean();
00095   }
00096   points_valid_=true;
00097   return points_;
00098 }
00099 
00100 //: Finds parameters and pose to best match to points
00101 //  All points equally weighted.
00102 //  If res_pt>0, and use_prior(), then effect of
00103 //  Gaussian prior is to scale parameters by
00104 //  mode_var/(mode_var+pt_var).
00105 void msm_ref_shape_instance::fit_to_points(const msm_points& pts,
00106                                            double pt_var)
00107 {
00108   // Catch case when fitting to self
00109   if (&pts == &points_) return;
00110 
00111   if (b_.size()==0) return;
00112 
00113   // Estimate shape parameters
00114   tmp_points_=pts;
00115   tmp_points_.vector()-=model().mean();
00116   mbl_matxvec_prod_vm(tmp_points_.vector(),model().modes(),b_);
00117 
00118   if (use_prior_ && pt_var>0.0)
00119   {
00120     const vnl_vector<double>& var = model().mode_var();
00121     for (unsigned i=0;i<b_.size();++i)
00122       b_[i]*=var[i]/(var[i]+pt_var);
00123   }
00124 
00125   param_limiter().apply_limit(b_);
00126 
00127   points_valid_=false;
00128 }
00129 
00130 void msm_calc_WP(const vnl_matrix<double>& P,
00131                  const vnl_vector<double>& wts,
00132                  unsigned n_modes,
00133                  vnl_matrix<double>& WP)
00134 {
00135   unsigned nr = P.rows();
00136   assert(nr==wts.size()*2);
00137   WP.set_size(nr,n_modes);
00138 
00139   double const*const* PData = P.data_array();
00140   double ** WPData = WP.data_array();
00141   const double* w=wts.data_block();
00142 
00143   for (unsigned i=0;i<nr;i+=2,++w)
00144   {
00145     const double* x=PData[i];
00146     const double* y=PData[i+1];
00147     double* wx=WPData[i];
00148     double* wy=WPData[i+1];
00149     for (unsigned j=0;j<n_modes;++j)
00150     {
00151       wx[j]=w[0]*x[j];
00152       wy[j]=w[0]*y[j];
00153     }
00154   }
00155 }
00156 
00157 // Premultiply P by block diagonal composed of wt_mat
00158 void msm_calc_WP(const vnl_matrix<double>& P,
00159                  const vcl_vector<msm_wt_mat_2d>& wt_mat,
00160                  unsigned n_modes,
00161                  vnl_matrix<double>& WP)
00162 {
00163   unsigned nr = P.rows();
00164   assert(nr==wt_mat.size()*2);
00165   WP.set_size(nr,n_modes);
00166 
00167   double const*const* PData = P.data_array();
00168   double ** WPData = WP.data_array();
00169   vcl_vector<msm_wt_mat_2d>::const_iterator w=wt_mat.begin();
00170 
00171   for (unsigned i=0;i<nr;i+=2,++w)
00172   {
00173     const double* x=PData[i];
00174     const double* y=PData[i+1];
00175     double* wx=WPData[i];
00176     double* wy=WPData[i+1];
00177     double w11=w->m11();
00178     double w12=w->m12();
00179     double w22=w->m22();
00180     for (unsigned j=0;j<n_modes;++j)
00181     {
00182       wx[j]=w11*x[j] + w12*y[j];
00183       wy[j]=w12*x[j] + w22*y[j];
00184     }
00185   }
00186 }
00187 
00188 // Solves Mb=rhs for b where M is assumed symmetric
00189 void msm_solve_sym_eqn(const vnl_matrix<double>& M,
00190                        const vnl_vector<double>& rhs,
00191                        vnl_vector<double>& b)
00192 {
00193   vnl_cholesky chol(M,vnl_cholesky::estimate_condition);
00194   if (chol.rcond()>1.0e-6)
00195   {
00196     chol.solve(rhs,&b);
00197   }
00198   else
00199   {
00200     // Solve using SVD
00201     double tol=1e-8;
00202     vnl_svd<double> svd(M);
00203     svd.zero_out_relative(tol);
00204     b = svd.solve(rhs);
00205   }
00206 }
00207 
00208 // Solve weighted version of Pb=dx, ie P'WPb=P'Wdx
00209 void msm_solve_for_b(const vnl_matrix<double>& P,
00210                      const vnl_vector<double>& var,
00211                      const vnl_vector<double>& wts,
00212                      const vnl_vector<double>& dx,
00213                      unsigned n_modes,
00214                      vnl_vector<double>& b, bool use_prior)
00215 {
00216   vnl_matrix<double> WP;
00217   msm_calc_WP(P,wts,n_modes,WP);
00218 
00219   vnl_vector<double> PtWdx(n_modes);
00220   mbl_matxvec_prod_vm(dx,WP,PtWdx);
00221 
00222   vnl_matrix<double> PtWP;
00223   mbl_matrix_product_at_b(PtWP,P,WP,n_modes);
00224 
00225   if (use_prior)  // Add 1/var to diagonal of PtWP
00226     for (unsigned i=0;i<n_modes;++i) PtWP(i,i)+=1.0/var(i);
00227 
00228   // Solves (PtWP)b = PtWdx
00229   msm_solve_sym_eqn(PtWP,PtWdx,b);
00230 }
00231 
00232 // Solve weighted version of Pb=dx, ie P'WPb=P'Wdx
00233 // W is block diagonal, with blocks wt_mat[i] (symmetrix 2x2)
00234 void msm_solve_for_b(const vnl_matrix<double>& P,
00235                      const vnl_vector<double>& var,
00236                      const vcl_vector<msm_wt_mat_2d>& wt_mat,
00237                      const vnl_vector<double>& dx,
00238                      unsigned n_modes,
00239                      vnl_vector<double>& b, bool use_prior)
00240 {
00241   vnl_matrix<double> WP;
00242   msm_calc_WP(P,wt_mat,n_modes,WP);
00243 
00244   vnl_vector<double> PtWdx(n_modes);
00245   mbl_matxvec_prod_vm(dx,WP,PtWdx);
00246 
00247   vnl_matrix<double> PtWP;
00248   mbl_matrix_product_at_b(PtWP,P,WP,n_modes);
00249 
00250   if (use_prior)  // Add 1/var to diagonal of PtWP
00251     for (unsigned i=0;i<n_modes;++i) PtWP(i,i)+=1.0/var(i);
00252 
00253   // Solves (PtWP)b = PtWdx
00254   msm_solve_sym_eqn(PtWP,PtWdx,b);
00255 }
00256 
00257 //: Finds parameters and pose to best match to points
00258 //  Errors on point i are weighted by wts[i]
00259 void msm_ref_shape_instance::fit_to_points_wt(const msm_points& pts,
00260                                               const vnl_vector<double>& wts)
00261 {
00262   // Catch case when fitting to self
00263   if (&pts == &points_) return;
00264   if (b_.size()==0) return;
00265 
00266   tmp_points_.vector()=pts.vector();
00267   tmp_points_.vector()-=model().mean();
00268 
00269   // Now must solve weighted linear equation P'WPb=P'Wdx
00270   msm_solve_for_b(model().modes(),model().mode_var(),
00271                   wts,tmp_points_.vector(),
00272                   b_.size(),b_,use_prior_);
00273 
00274   param_limiter().apply_limit(b_);
00275 
00276   points_valid_=false;
00277 }
00278 
// NOTE(review): the function below is compiled out (#if 0) and nothing in
// this file calls it. It is kept for reference only.
#if 0
// Calculates W2=T'WT where T is 2x2 matrix (a,-b;b,a)
// (presumably a scaled rotation with a=s*cos(theta), b=s*sin(theta) --
// confirm before re-enabling). W is assumed symmetric: only W[0][0],
// W[0][1] and W[1][1] are read, and W2 is written as a symmetric matrix.
// Note the mixed element access styles (W[i][j] and W(i,j)).
void msm_transform_wt_mat(const vnl_double_2x2& W,
                          double a, double b, vnl_double_2x2& W2)
{
  W2(0,0)=a*a*W[0][0]+2*a*b*W[0][1]+b*b*W[1][1];
  W2(0,1)=a*a*W[0][1]+a*b*(W(1,1)-W[0][0])-b*b*W[0][1];
  W2(1,0)=W2(0,1);
  W2(1,1)=a*a*W[1][1]-2*a*b*W[0][1]+b*b*W[0][0];
}
#endif // 0
00290 
00291 //: Finds parameters and pose to best match to points
00292 //  Errors on point i are weighted by wt_mat[i] in target frame
00293 void msm_ref_shape_instance::fit_to_points_wt_mat(const msm_points& pts,
00294                                                   const vcl_vector<msm_wt_mat_2d>& wt_mat)
00295 {
00296   // Catch case when fitting to self
00297   if (&pts == &points_) return;
00298   if (b_.size()==0) return;
00299 
00300   assert(wt_mat.size()==model().size());
00301 
00302   tmp_points_.vector()=pts.vector();
00303   tmp_points_.vector()-=model().mean();
00304 
00305   // Now must solve weighted linear equation P'WPb=P'Wdx
00306   msm_solve_for_b(model().modes(),model().mode_var(),
00307                   wt_mat,tmp_points_.vector(),
00308                   b_.size(),b_,use_prior_);
00309 
00310   param_limiter().apply_limit(b_);
00311 
00312   points_valid_=false;
00313 }
00314 
00315 
00316 //=======================================================================
00317 // Method: version_no
00318 //=======================================================================
00319 
00320 short msm_ref_shape_instance::version_no() const
00321 {
00322   return 1;
00323 }
00324 
00325 //=======================================================================
00326 // Method: is_a
00327 //=======================================================================
00328 
00329 vcl_string msm_ref_shape_instance::is_a() const
00330 {
00331   return vcl_string("msm_ref_shape_instance");
00332 }
00333 
00334 //=======================================================================
00335 // Method: print
00336 //=======================================================================
00337 
00338   // required if data is present in this class
00339 void msm_ref_shape_instance::print_summary(vcl_ostream& os) const
00340 {
00341 }
00342 
00343 //=======================================================================
00344 // Method: save
00345 //=======================================================================
00346 
00347   // required if data is present in this class
00348 void msm_ref_shape_instance::b_write(vsl_b_ostream& bfs) const
00349 {
00350   vsl_b_write(bfs,version_no());
00351   vsl_b_write(bfs,b_);
00352   vsl_b_write(bfs,param_limiter_);
00353   vsl_b_write(bfs,use_prior_);
00354 }
00355 
00356 //=======================================================================
00357 // Method: load
00358 //=======================================================================
00359 
00360   // required if data is present in this class
00361 void msm_ref_shape_instance::b_read(vsl_b_istream& bfs)
00362 {
00363   short version;
00364   vsl_b_read(bfs,version);
00365   switch (version)
00366   {
00367     case (1):
00368       vsl_b_read(bfs,b_);
00369       vsl_b_read(bfs,param_limiter_);
00370       vsl_b_read(bfs,use_prior_);
00371       break;
00372     default:
00373       vcl_cerr << "msm_ref_shape_instance::b_read() :\n"
00374                << "Unexpected version number " << version << vcl_endl;
00375       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00376       return;
00377   }
00378 
00379   points_valid_=false;
00380   points_valid_=false;
00381 }
00382 
00383 
00384 //=======================================================================
// Associated function: vsl_b_write
00386 //=======================================================================
00387 
00388 void vsl_b_write(vsl_b_ostream& bfs, const msm_ref_shape_instance& b)
00389 {
00390   b.b_write(bfs);
00391 }
00392 
00393 //=======================================================================
// Associated function: vsl_b_read
00395 //=======================================================================
00396 
00397 void vsl_b_read(vsl_b_istream& bfs, msm_ref_shape_instance& b)
00398 {
00399   b.b_read(bfs);
00400 }
00401 
00402 //=======================================================================
00403 // Associated function: operator<<
00404 //=======================================================================
00405 
00406 vcl_ostream& operator<<(vcl_ostream& os,const msm_ref_shape_instance& b)
00407 {
00408   os << b.is_a() << ": ";
00409   vsl_indent_inc(os);
00410   b.print_summary(os);
00411   vsl_indent_dec(os);
00412   return os;
00413 }
00414 
00415 //: Stream output operator for class reference
00416 void vsl_print_summary(vcl_ostream& os,const msm_ref_shape_instance& b)
00417 {
00418  os << b;
00419 }