00001 #include "msm_ref_shape_instance.h"
00002
00003
00004
00005
00006
00007 #include <vsl/vsl_indent.h>
00008 #include <vsl/vsl_binary_io.h>
00009 #include <vnl/io/vnl_io_vector.h>
00010 #include <vnl/algo/vnl_cholesky.h>
00011 #include <vnl/algo/vnl_svd.h>
00012
00013 #include <vcl_cstdlib.h>
00014 #include <vcl_iostream.h>
00015 #include <vcl_cassert.h>
00016
00017 #include <msm/msm_ref_shape_model.h>
00018 #include <msm/msm_no_limiter.h>
00019 #include <mbl/mbl_matxvec.h>
00020 #include <mbl/mbl_matrix_products.h>
00021
00022
00023
00024
00025
00026 msm_ref_shape_instance::msm_ref_shape_instance()
00027 : model_(0),use_prior_(false)
00028 {
00029 param_limiter_ = msm_no_limiter();
00030 }
00031
00032
00033
00034
00035
00036 msm_ref_shape_instance::~msm_ref_shape_instance()
00037 {
00038 }
00039
00040
00041 void msm_ref_shape_instance::set_shape_model(const msm_ref_shape_model& model)
00042 {
00043 model_=&model;
00044
00045 b_.set_size(model.n_modes());
00046 b_.fill(0);
00047
00048 param_limiter_ = model.param_limiter().clone();
00049
00050 points_valid_=false;
00051 }
00052
00053
00054 void msm_ref_shape_instance::set_param_limiter(const msm_param_limiter& limiter)
00055 {
00056 param_limiter_ = limiter;
00057 }
00058
00059
00060 void msm_ref_shape_instance::set_use_prior(bool b)
00061 {
00062 use_prior_ = b;
00063 }
00064
00065
00066 void msm_ref_shape_instance::set_params(const vnl_vector<double>& b)
00067 {
00068 assert(b.size()<=model().n_modes());
00069 b_=b;
00070 points_valid_=false;
00071 }
00072
00073
00074 void msm_ref_shape_instance::set_to_mean()
00075 {
00076 if (b_.size()==0) return;
00077 b_.fill(0.0);
00078 points_valid_=false;
00079 }
00080
00081
00082
00083 const msm_points& msm_ref_shape_instance::points()
00084 {
00085 if (points_valid_) return points_;
00086
00087
00088 if (b_.size()==0)
00089 points_.vector()=model().mean();
00090 else
00091 {
00092
00093 mbl_matxvec_prod_mv(model().modes(),b_,points_.vector());
00094 points_.vector() += model().mean();
00095 }
00096 points_valid_=true;
00097 return points_;
00098 }
00099
00100
00101
00102
00103
00104
00105 void msm_ref_shape_instance::fit_to_points(const msm_points& pts,
00106 double pt_var)
00107 {
00108
00109 if (&pts == &points_) return;
00110
00111 if (b_.size()==0) return;
00112
00113
00114 tmp_points_=pts;
00115 tmp_points_.vector()-=model().mean();
00116 mbl_matxvec_prod_vm(tmp_points_.vector(),model().modes(),b_);
00117
00118 if (use_prior_ && pt_var>0.0)
00119 {
00120 const vnl_vector<double>& var = model().mode_var();
00121 for (unsigned i=0;i<b_.size();++i)
00122 b_[i]*=var[i]/(var[i]+pt_var);
00123 }
00124
00125 param_limiter().apply_limit(b_);
00126
00127 points_valid_=false;
00128 }
00129
00130 void msm_calc_WP(const vnl_matrix<double>& P,
00131 const vnl_vector<double>& wts,
00132 unsigned n_modes,
00133 vnl_matrix<double>& WP)
00134 {
00135 unsigned nr = P.rows();
00136 assert(nr==wts.size()*2);
00137 WP.set_size(nr,n_modes);
00138
00139 double const*const* PData = P.data_array();
00140 double ** WPData = WP.data_array();
00141 const double* w=wts.data_block();
00142
00143 for (unsigned i=0;i<nr;i+=2,++w)
00144 {
00145 const double* x=PData[i];
00146 const double* y=PData[i+1];
00147 double* wx=WPData[i];
00148 double* wy=WPData[i+1];
00149 for (unsigned j=0;j<n_modes;++j)
00150 {
00151 wx[j]=w[0]*x[j];
00152 wy[j]=w[0]*y[j];
00153 }
00154 }
00155 }
00156
00157
00158 void msm_calc_WP(const vnl_matrix<double>& P,
00159 const vcl_vector<msm_wt_mat_2d>& wt_mat,
00160 unsigned n_modes,
00161 vnl_matrix<double>& WP)
00162 {
00163 unsigned nr = P.rows();
00164 assert(nr==wt_mat.size()*2);
00165 WP.set_size(nr,n_modes);
00166
00167 double const*const* PData = P.data_array();
00168 double ** WPData = WP.data_array();
00169 vcl_vector<msm_wt_mat_2d>::const_iterator w=wt_mat.begin();
00170
00171 for (unsigned i=0;i<nr;i+=2,++w)
00172 {
00173 const double* x=PData[i];
00174 const double* y=PData[i+1];
00175 double* wx=WPData[i];
00176 double* wy=WPData[i+1];
00177 double w11=w->m11();
00178 double w12=w->m12();
00179 double w22=w->m22();
00180 for (unsigned j=0;j<n_modes;++j)
00181 {
00182 wx[j]=w11*x[j] + w12*y[j];
00183 wy[j]=w12*x[j] + w22*y[j];
00184 }
00185 }
00186 }
00187
00188
00189 void msm_solve_sym_eqn(const vnl_matrix<double>& M,
00190 const vnl_vector<double>& rhs,
00191 vnl_vector<double>& b)
00192 {
00193 vnl_cholesky chol(M,vnl_cholesky::estimate_condition);
00194 if (chol.rcond()>1.0e-6)
00195 {
00196 chol.solve(rhs,&b);
00197 }
00198 else
00199 {
00200
00201 double tol=1e-8;
00202 vnl_svd<double> svd(M);
00203 svd.zero_out_relative(tol);
00204 b = svd.solve(rhs);
00205 }
00206 }
00207
00208
00209 void msm_solve_for_b(const vnl_matrix<double>& P,
00210 const vnl_vector<double>& var,
00211 const vnl_vector<double>& wts,
00212 const vnl_vector<double>& dx,
00213 unsigned n_modes,
00214 vnl_vector<double>& b, bool use_prior)
00215 {
00216 vnl_matrix<double> WP;
00217 msm_calc_WP(P,wts,n_modes,WP);
00218
00219 vnl_vector<double> PtWdx(n_modes);
00220 mbl_matxvec_prod_vm(dx,WP,PtWdx);
00221
00222 vnl_matrix<double> PtWP;
00223 mbl_matrix_product_at_b(PtWP,P,WP,n_modes);
00224
00225 if (use_prior)
00226 for (unsigned i=0;i<n_modes;++i) PtWP(i,i)+=1.0/var(i);
00227
00228
00229 msm_solve_sym_eqn(PtWP,PtWdx,b);
00230 }
00231
00232
00233
00234 void msm_solve_for_b(const vnl_matrix<double>& P,
00235 const vnl_vector<double>& var,
00236 const vcl_vector<msm_wt_mat_2d>& wt_mat,
00237 const vnl_vector<double>& dx,
00238 unsigned n_modes,
00239 vnl_vector<double>& b, bool use_prior)
00240 {
00241 vnl_matrix<double> WP;
00242 msm_calc_WP(P,wt_mat,n_modes,WP);
00243
00244 vnl_vector<double> PtWdx(n_modes);
00245 mbl_matxvec_prod_vm(dx,WP,PtWdx);
00246
00247 vnl_matrix<double> PtWP;
00248 mbl_matrix_product_at_b(PtWP,P,WP,n_modes);
00249
00250 if (use_prior)
00251 for (unsigned i=0;i<n_modes;++i) PtWP(i,i)+=1.0/var(i);
00252
00253
00254 msm_solve_sym_eqn(PtWP,PtWdx,b);
00255 }
00256
00257
00258
00259 void msm_ref_shape_instance::fit_to_points_wt(const msm_points& pts,
00260 const vnl_vector<double>& wts)
00261 {
00262
00263 if (&pts == &points_) return;
00264 if (b_.size()==0) return;
00265
00266 tmp_points_.vector()=pts.vector();
00267 tmp_points_.vector()-=model().mean();
00268
00269
00270 msm_solve_for_b(model().modes(),model().mode_var(),
00271 wts,tmp_points_.vector(),
00272 b_.size(),b_,use_prior_);
00273
00274 param_limiter().apply_limit(b_);
00275
00276 points_valid_=false;
00277 }
00278
#if 0
//: Transform a symmetric 2x2 weight matrix: W2 = R * W * R',
//  with R = [ a  b ; -b  a ].  W is assumed symmetric (W(1,0)==W(0,1)).
//  Inputs are copied first, so W2 may alias W.
//  Currently disabled (not compiled).
void msm_transform_wt_mat(const vnl_double_2x2& W,
                          double a, double b, vnl_double_2x2& W2)
{
  const double w00 = W(0,0);
  const double w01 = W(0,1);
  const double w11 = W(1,1);
  W2(0,0) = a*a*w00 + 2*a*b*w01 + b*b*w11;
  W2(0,1) = a*a*w01 + a*b*(w11-w00) - b*b*w01;
  W2(1,0) = W2(0,1);
  W2(1,1) = a*a*w11 - 2*a*b*w01 + b*b*w00;
}
#endif // 0
00290
00291
00292
00293 void msm_ref_shape_instance::fit_to_points_wt_mat(const msm_points& pts,
00294 const vcl_vector<msm_wt_mat_2d>& wt_mat)
00295 {
00296
00297 if (&pts == &points_) return;
00298 if (b_.size()==0) return;
00299
00300 assert(wt_mat.size()==model().size());
00301
00302 tmp_points_.vector()=pts.vector();
00303 tmp_points_.vector()-=model().mean();
00304
00305
00306 msm_solve_for_b(model().modes(),model().mode_var(),
00307 wt_mat,tmp_points_.vector(),
00308 b_.size(),b_,use_prior_);
00309
00310 param_limiter().apply_limit(b_);
00311
00312 points_valid_=false;
00313 }
00314
00315
00316
00317
00318
00319
00320 short msm_ref_shape_instance::version_no() const
00321 {
00322 return 1;
00323 }
00324
00325
00326
00327
00328
00329 vcl_string msm_ref_shape_instance::is_a() const
00330 {
00331 return vcl_string("msm_ref_shape_instance");
00332 }
00333
00334
00335
00336
00337
00338
00339 void msm_ref_shape_instance::print_summary(vcl_ostream& os) const
00340 {
00341 }
00342
00343
00344
00345
00346
00347
00348 void msm_ref_shape_instance::b_write(vsl_b_ostream& bfs) const
00349 {
00350 vsl_b_write(bfs,version_no());
00351 vsl_b_write(bfs,b_);
00352 vsl_b_write(bfs,param_limiter_);
00353 vsl_b_write(bfs,use_prior_);
00354 }
00355
00356
00357
00358
00359
00360
00361 void msm_ref_shape_instance::b_read(vsl_b_istream& bfs)
00362 {
00363 short version;
00364 vsl_b_read(bfs,version);
00365 switch (version)
00366 {
00367 case (1):
00368 vsl_b_read(bfs,b_);
00369 vsl_b_read(bfs,param_limiter_);
00370 vsl_b_read(bfs,use_prior_);
00371 break;
00372 default:
00373 vcl_cerr << "msm_ref_shape_instance::b_read() :\n"
00374 << "Unexpected version number " << version << vcl_endl;
00375 bfs.is().clear(vcl_ios::badbit);
00376 return;
00377 }
00378
00379 points_valid_=false;
00380 points_valid_=false;
00381 }
00382
00383
00384
00385
00386
00387
00388 void vsl_b_write(vsl_b_ostream& bfs, const msm_ref_shape_instance& b)
00389 {
00390 b.b_write(bfs);
00391 }
00392
00393
00394
00395
00396
00397 void vsl_b_read(vsl_b_istream& bfs, msm_ref_shape_instance& b)
00398 {
00399 b.b_read(bfs);
00400 }
00401
00402
00403
00404
00405
00406 vcl_ostream& operator<<(vcl_ostream& os,const msm_ref_shape_instance& b)
00407 {
00408 os << b.is_a() << ": ";
00409 vsl_indent_inc(os);
00410 b.print_summary(os);
00411 vsl_indent_dec(os);
00412 return os;
00413 }
00414
00415
00416 void vsl_print_summary(vcl_ostream& os,const msm_ref_shape_instance& b)
00417 {
00418 os << b;
00419 }