contrib/mul/msm/msm_shape_model_builder.cxx
Go to the documentation of this file.
00001 #include "msm_shape_model_builder.h"
00002 //:
00003 // \file
00004 // \brief Object to build a msm_shape_model
00005 // \author Tim Cootes
00006 
00007 #include <vcl_iostream.h>
00008 #include <vsl/vsl_indent.h>
00009 #include <vsl/vsl_binary_io.h>
00010 #include <vnl/io/vnl_io_vector.h>
00011 #include <vnl/io/vnl_io_matrix.h>
00012 #include <vcl_cstdlib.h>  // for vcl_atoi() & vcl_abort()
00013 #include <mbl/mbl_data_array_wrapper.h>
00014 #include <mcal/mcal_pca.h>
00015 #include <mcal/mcal_extract_mode.h>
00016 #include <mbl/mbl_exception.h>
00017 
00018 //=======================================================================
00019 // Dflt ctor
00020 //=======================================================================
00021 
00022 msm_shape_model_builder::msm_shape_model_builder()
00023   : var_prop_(0.98),min_modes_(0),max_modes_(9999)
00024 {
00025 }
00026 
00027 //=======================================================================
00028 // Destructor
00029 //=======================================================================
00030 
00031 msm_shape_model_builder::~msm_shape_model_builder()
00032 {
00033 }
00034 
00035 //: Set up model
00036 void msm_shape_model_builder::set_aligner(
00037            const msm_aligner& aligner)
00038 {
00039   aligner_ = aligner;
00040 }
00041 
00042 //: Define parameter limiter.
00043 void msm_shape_model_builder::set_param_limiter(const msm_param_limiter& p)
00044 {
00045   param_limiter_=p;
00046 }
00047 
00048 void msm_shape_model_builder::set_mode_choice(unsigned min, unsigned max,
00049                                               double var_proportion)
00050 {
00051   min_modes_ = min;
00052   max_modes_ = max;
00053   var_prop_ = var_proportion;
00054 }
00055 
00056 
00057 //: Builds the model from the supplied examples
00058 void msm_shape_model_builder::build_model(
00059                    const vcl_vector<msm_points>& shapes,
00060                    msm_shape_model& shape_model)
00061 {
00062   // Align shapes and estimate mean pose
00063   msm_points ref_mean_shape;
00064   vcl_vector<vnl_vector<double> > pose_to_ref;
00065   vnl_vector<double> average_pose;
00066   aligner().align_set(shapes,ref_mean_shape,pose_to_ref,average_pose);
00067 
00068   // Generate vectors corresponding to aligned shapes
00069   unsigned n = shapes.size();
00070   vcl_vector<vnl_vector<double> > aligned_shape_vec(n);
00071   msm_points aligned_shape;
00072   for (unsigned i=0;i<n;++i)
00073   {
00074     aligner().apply_transform(shapes[i],pose_to_ref[i],aligned_shape);
00075     aligned_shape_vec[i]=aligned_shape.vector();
00076   }
00077 
00078   mcal_pca pca;
00079   pca.set_mode_choice(min_modes_,max_modes_,var_prop_);
00080 
00081   vnl_matrix<double> modes;
00082   vnl_vector<double> mode_var;
00083 
00084   mbl_data_array_wrapper<vnl_vector<double> >
00085     data(&aligned_shape_vec[0],n);
00086 
00087   pca.build_about_mean(data,ref_mean_shape.vector(),
00088                        modes,mode_var);
00089 
00090   param_limiter_->set_param_var(mode_var);
00091 
00092   shape_model.set(ref_mean_shape,modes,mode_var,average_pose,
00093                   aligner(),param_limiter());
00094 }
00095 
00096 //: Each point p controls two elements 2p,2p+1
00097 inline void msm_elements_from_pts_used(
00098                    const vcl_vector<vcl_vector<unsigned> >& pts_used,
00099                    vcl_vector<vcl_vector<unsigned> >& used)
00100 {
00101   used.resize(pts_used.size());
00102   for (unsigned i=0;i<pts_used.size();++i)
00103   {
00104     used[i].empty();
00105     used[i].reserve(2*pts_used[i].size());
00106     for (unsigned j=0;j<pts_used[i].size();++j)
00107     {
00108       used[i].push_back(2*pts_used[i][j]);
00109       used[i].push_back(2*pts_used[i][j]+1);
00110     }
00111   }
00112 }
00113 
00114 //: Builds the model, using subsets of elements for some modes
00115 //  Builds a shape model, allowing control of which elements may
00116 //  be varied in some of the modes.  This allows construction
00117 //  of models where some groups of points are semi-independent
00118 //  of the others.
00119 //  \param pts_used[i] indicates the set of elements to be used for
00120 //  mode i (or all if \p pts_used[i] is empty).
00121 //  Modes beyond \p pts_used.size() will use all elements.
00122 //  Builds at least \p pts_used.size() modes. Number defined by
00123 //  max_modes and var_prop.
00124 void msm_shape_model_builder::build_model(
00125                   const vcl_vector<msm_points>& shapes,
00126                   const vcl_vector<vcl_vector<unsigned> >& pts_used,
00127                   msm_shape_model& shape_model)
00128 {
00129   // Align shapes and estimate mean pose
00130   msm_points ref_mean_shape;
00131   vcl_vector<vnl_vector<double> > pose_to_ref;
00132   vnl_vector<double> average_pose;
00133   aligner().align_set(shapes,ref_mean_shape,pose_to_ref,average_pose);
00134 
00135   // Generate vectors corresponding to aligned shapes
00136   unsigned n = shapes.size();
00137   vcl_vector<vnl_vector<double> > dshape_vec(n);
00138   msm_points aligned_shape;
00139   for (unsigned i=0;i<n;++i)
00140   {
00141     aligner().apply_transform(shapes[i],pose_to_ref[i],aligned_shape);
00142     dshape_vec[i]=aligned_shape.vector()-ref_mean_shape.vector();
00143   }
00144 
00145   // Set up indication for which elements to be used
00146   // pt i corresponds to elements 2i,2i+1
00147   vcl_vector<vcl_vector<unsigned> > used;
00148   msm_elements_from_pts_used(pts_used,used);
00149 
00150   vnl_matrix<double> modes;
00151   vnl_vector<double> mode_var;
00152 
00153   mcal_extract_modes(dshape_vec,used,
00154                      max_modes_,var_prop_,
00155                      modes,mode_var);
00156 
00157   param_limiter_->set_param_var(mode_var);
00158 
00159   shape_model.set(ref_mean_shape,modes,mode_var,average_pose,
00160                   aligner(),param_limiter());
00161 }
00162 
00163 //: Builds the model from the points loaded from given files
00164 void msm_shape_model_builder::build_from_files(
00165                    const vcl_string& points_dir,
00166                    const vcl_vector<vcl_string>& filenames,
00167                    msm_shape_model& shape_model)
00168 {
00169   vcl_vector<msm_points> shapes(filenames.size());
00170   msm_load_shapes(points_dir,filenames,shapes);
00171   build_model(shapes,shape_model);
00172 }
00173 
00174 //: Counts number of examples with each class ID
00175 //  Assumes IDs run from 0.  Ignores any elements with \p id[i]<0.
00176 //  On exit n_per_class[j] indicates number of class j
00177 //  It is resized to cope with the largest ID number. Some elements
00178 //  may be zero.
00179 static void msm_count_classes(const vcl_vector<int>& id,
00180                               vcl_vector<unsigned>& n_per_class)
00181 {
00182   int max_id = 0;
00183   for (unsigned i=0;i<id.size();++i)
00184     if (id[i]>max_id) max_id=id[i];
00185 
00186   n_per_class.resize(1+max_id,0u);
00187 
00188   for (unsigned i=0;i<id.size();++i)
00189     if (id[i]>=0) n_per_class[id[i]]++;
00190 }
00191 
00192 
00193 //: Builds shape model from within-class variation
00194 //  \param shape[i] belongs to class \p id[i].
00195 //  Aligns all shapes to a common mean.
00196 //  Computes the average covariance about each class mean,
00197 //  and builds shape modes from this.
00198 //
00199 //  If \p id[i]<0, then shape is
00200 //  used for building global mean, but not for within class model.
00201 void msm_shape_model_builder::build_within_class_model(
00202                    const vcl_vector<msm_points>& shapes,
00203                    const vcl_vector<int>& id,
00204                    msm_shape_model& shape_model)
00205 {
00206   // Align shapes and estimate mean pose
00207   msm_points ref_mean_shape;
00208   vcl_vector<vnl_vector<double> > pose_to_ref;
00209   vnl_vector<double> average_pose;
00210   aligner().align_set(shapes,ref_mean_shape,pose_to_ref,average_pose);
00211 
00212   vcl_vector<unsigned> n_per_class;
00213   msm_count_classes(id,n_per_class);
00214   vcl_vector<vnl_vector<double> > class_mean(n_per_class.size());
00215 
00216   // Initialise sums for class means
00217   for (unsigned j=0;j<n_per_class.size();++j)
00218     if (n_per_class[j]>0)
00219     {
00220       class_mean[j].set_size(2*ref_mean_shape.size());
00221       class_mean[j].fill(0.0);
00222     }
00223 
00224   // Generate vectors corresponding to aligned shapes
00225   unsigned n = shapes.size();
00226   vcl_vector<vnl_vector<double> > dshape_vec;
00227   vcl_vector<int> valid_id;
00228   dshape_vec.reserve(n);
00229   valid_id.reserve(n);
00230 
00231   msm_points aligned_shape;
00232   for (unsigned i=0;i<n;++i)
00233   {
00234     if (id[i]<0) continue;  // Ignore unknown class id
00235     aligner().apply_transform(shapes[i],pose_to_ref[i],aligned_shape);
00236     dshape_vec.push_back(aligned_shape.vector());
00237     valid_id.push_back(id[i]);
00238 
00239     class_mean[id[i]]+=aligned_shape.vector();
00240   }
00241 
00242   // Compute the mean for each class from the sums
00243   for (unsigned j=0;j<n_per_class.size();++j)
00244     if (n_per_class[j]>0) class_mean[j]/=n_per_class[j];
00245 
00246   // Remove mean from each example
00247   for (unsigned i=0;i<dshape_vec.size();++i)
00248     dshape_vec[i]-=class_mean[valid_id[i]];
00249 
00250   // Vectors are now about a zero mean.
00251   // Apply PCA to this data to compute the modes.
00252   mcal_pca pca;
00253   pca.set_mode_choice(min_modes_,max_modes_,var_prop_);
00254 
00255   vnl_matrix<double> modes;
00256   vnl_vector<double> mode_var;
00257 
00258   mbl_data_array_wrapper<vnl_vector<double> > data(dshape_vec);
00259   vnl_vector<double> zero_mean(2*ref_mean_shape.size(),0.0);
00260 
00261   pca.build_about_mean(data,zero_mean, modes,mode_var);
00262 
00263   param_limiter_->set_param_var(mode_var);
00264 
00265   shape_model.set(ref_mean_shape,modes,mode_var,average_pose,
00266                   aligner(),param_limiter());
00267 }
00268 
00269 //: Builds shape model from within-class variation
00270 //  \param shape[i] belongs to class \p id[i].
00271 //  Aligns all shapes to a common mean.
00272 //  Computes the average covariance about each class mean,
00273 //  and builds shape modes from this.
00274 //
00275 //  If \p id[i]<0, then shape is
00276 //  used for building global mean, but not for within class model.
00277 //
00278 //  \param pts_used[i] indicates which points will be controlled by mode i.
00279 void msm_shape_model_builder::build_within_class_model(
00280                    const vcl_vector<msm_points>& shapes,
00281                    const vcl_vector<int>& id,
00282                    const vcl_vector<vcl_vector<unsigned> >& pts_used,
00283                    msm_shape_model& shape_model)
00284 {
00285   // Align shapes and estimate mean pose
00286   msm_points ref_mean_shape;
00287   vcl_vector<vnl_vector<double> > pose_to_ref;
00288   vnl_vector<double> average_pose;
00289   aligner().align_set(shapes,ref_mean_shape,pose_to_ref,average_pose);
00290 
00291   vcl_vector<unsigned> n_per_class;
00292   msm_count_classes(id,n_per_class);
00293   vcl_vector<vnl_vector<double> > class_mean(n_per_class.size());
00294 
00295   // Initialise sums for class means
00296   for (unsigned j=0;j<n_per_class.size();++j)
00297     if (n_per_class[j]>0)
00298     {
00299       class_mean[j].set_size(2*ref_mean_shape.size());
00300       class_mean[j].fill(0.0);
00301     }
00302 
00303   // Generate vectors corresponding to aligned shapes
00304   unsigned n = shapes.size();
00305   vcl_vector<vnl_vector<double> > dshape_vec;
00306   vcl_vector<int> valid_id;
00307   dshape_vec.reserve(n);
00308   valid_id.reserve(n);
00309 
00310   msm_points aligned_shape;
00311   for (unsigned i=0;i<n;++i)
00312   {
00313     if (id[i]<0) continue;  // Ignore unknown class id
00314     aligner().apply_transform(shapes[i],pose_to_ref[i],aligned_shape);
00315     dshape_vec.push_back(aligned_shape.vector());
00316     valid_id.push_back(id[i]);
00317 
00318     class_mean[id[i]]+=aligned_shape.vector();
00319   }
00320 
00321   // Compute the mean for each class from the sums
00322   for (unsigned j=0;j<n_per_class.size();++j)
00323     if (n_per_class[j]>0) class_mean[j]/=n_per_class[j];
00324 
00325   // Remove mean from each example
00326   for (unsigned i=0;i<dshape_vec.size();++i)
00327     dshape_vec[i]-=class_mean[valid_id[i]];
00328 
00329   // Set up indication for which elements to be used
00330   // pt i corresponds to elements 2i,2i+1
00331   vcl_vector<vcl_vector<unsigned> > used;
00332   msm_elements_from_pts_used(pts_used,used);
00333 
00334   vnl_matrix<double> modes;
00335   vnl_vector<double> mode_var;
00336 
00337   mcal_extract_modes(dshape_vec,used,
00338                      max_modes_,var_prop_,
00339                      modes,mode_var);
00340 
00341   param_limiter_->set_param_var(mode_var);
00342 
00343   shape_model.set(ref_mean_shape,modes,mode_var,average_pose,
00344                   aligner(),param_limiter());
00345 }
00346 
00347 
00348 //: Loads all shapes from \p points_dir/filenames[i].
00349 //  Throws mbl_exception_parse_error if fails.
00350 void msm_load_shapes(const vcl_string& points_dir,
00351                      const vcl_vector<vcl_string>& filenames,
00352                      vcl_vector<msm_points>& shapes)
00353 {
00354   unsigned n=filenames.size();
00355   shapes.resize(n);
00356   for (unsigned i=0;i<n;++i)
00357   {
00358     vcl_string path = points_dir+"/"+filenames[i];
00359     if (!shapes[i].read_text_file(path))
00360     {
00361       mbl_exception_parse_error x("Failed to load points from "+path);
00362       mbl_exception_error(x);
00363     }
00364   }
00365 }
00366 
00367 //=======================================================================
00368 // Method: version_no
00369 //=======================================================================
00370 
00371 short msm_shape_model_builder::version_no() const
00372 {
00373   return 1;
00374 }
00375 
00376 //=======================================================================
00377 // Method: is_a
00378 //=======================================================================
00379 
00380 vcl_string msm_shape_model_builder::is_a() const
00381 {
00382   return vcl_string("msm_shape_model_builder");
00383 }
00384 
00385 //=======================================================================
00386 // Method: print
00387 //=======================================================================
00388 
00389   // required if data is present in this class
00390 void msm_shape_model_builder::print_summary(vcl_ostream& os) const
00391 {
00392   os<<'\n'<<vsl_indent()<<"aligner: ";
00393   if (aligner_.isDefined()) os<<aligner_; else os<<"-";
00394   os<<vsl_indent()<< "param_limiter: ";
00395   if (param_limiter_.isDefined())
00396     os<<param_limiter_; else os<<"-";
00397   os<<vsl_indent()<<"min_modes: "<<min_modes_<<'\n'
00398     <<vsl_indent()<<"max_modes: "<<max_modes_<<'\n'
00399     <<vsl_indent()<<"var_prop: "<<var_prop_;
00400 }
00401 
00402 //=======================================================================
00403 // Method: save
00404 //=======================================================================
00405 
00406   // required if data is present in this class
00407 void msm_shape_model_builder::b_write(vsl_b_ostream& bfs) const
00408 {
00409   vsl_b_write(bfs,version_no());
00410   vsl_b_write(bfs,aligner_);
00411   vsl_b_write(bfs,param_limiter_);
00412   vsl_b_write(bfs,min_modes_);
00413   vsl_b_write(bfs,max_modes_);
00414   vsl_b_write(bfs,var_prop_);
00415 }
00416 
00417 //=======================================================================
00418 // Method: load
00419 //=======================================================================
00420 
00421   // required if data is present in this class
00422 void msm_shape_model_builder::b_read(vsl_b_istream& bfs)
00423 {
00424   short version;
00425   vsl_b_read(bfs,version);
00426   switch (version)
00427   {
00428     case (1):
00429       vsl_b_read(bfs,aligner_);
00430       vsl_b_read(bfs,param_limiter_);
00431       vsl_b_read(bfs,min_modes_);
00432       vsl_b_read(bfs,max_modes_);
00433       vsl_b_read(bfs,var_prop_);
00434       break;
00435     default:
00436       vcl_cerr << "msm_shape_model_builder::b_read() :\n"
00437                << "Unexpected version number " << version << vcl_endl;
00438       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00439       return;
00440   }
00441 }
00442 
00443 
00444 //=======================================================================
00445 // Associated function: operator<<
00446 //=======================================================================
00447 
00448 void vsl_b_write(vsl_b_ostream& bfs, const msm_shape_model_builder& b)
00449 {
00450   b.b_write(bfs);
00451 }
00452 
00453 //=======================================================================
00454 // Associated function: operator>>
00455 //=======================================================================
00456 
00457 void vsl_b_read(vsl_b_istream& bfs, msm_shape_model_builder& b)
00458 {
00459   b.b_read(bfs);
00460 }
00461 
00462 //=======================================================================
00463 // Associated function: operator<<
00464 //=======================================================================
00465 
00466 vcl_ostream& operator<<(vcl_ostream& os,const msm_shape_model_builder& b)
00467 {
00468   os << b.is_a() << ": ";
00469   vsl_indent_inc(os);
00470   b.print_summary(os);
00471   vsl_indent_dec(os);
00472   return os;
00473 }
00474 
00475 //: Stream output operator for class reference
00476 void vsl_print_summary(vcl_ostream& os,const msm_shape_model_builder& b)
00477 {
00478  os << b;
00479 }