contrib/mul/msm/msm_shape_perturber.cxx
Go to the documentation of this file.
00001 #include "msm_shape_perturber.h"
00002 #include <vcl_cassert.h>
00003 
00004 msm_shape_perturber::msm_shape_perturber()
00005 {
00006   rel_gauss_ = 0.5;
00007   n_pose_    = 0;
00008   n_params_  = 0;
00009 }
00010 
00011 void msm_shape_perturber::set_model( const msm_shape_model& in )
00012 {
00013   sm_ = in;
00014   sm_inst_.set_shape_model( sm_ );
00015   gt_inst_.set_shape_model( sm_ );
00016 
00017   n_pose_   = sm_inst_.pose().size();
00018   dpose_.set_size( n_pose_ );
00019 
00020   n_params_ = sm_inst_.params().size();
00021   dparams_.set_size( n_params_ );
00022 
00023   vcl_size_t np = n_pose_ + n_params_;
00024 
00025   all_.set_size( np );
00026   dall_.set_size( np );
00027   inv_dall_.set_size( np );
00028 }
00029 
00030 void msm_shape_perturber::perturb( const msm_points& pts )
00031 {
00032   assert( max_dpose_.size() != 0 );
00033   assert( max_dpose_.size() == n_pose_ );
00034 
00035   gt_inst_.fit_to_points( pts );
00036   sm_inst_.fit_to_points( pts );
00037 
00038   const msm_aligner& aligner = sm_.aligner();
00039 
00040   // Generate random pose displacement
00041   for ( unsigned i=0; i<n_pose_; ++i )
00042     dpose_[i] = random_value( rand_, max_dpose_[i], rel_gauss_ );
00043 
00044   vnl_vector<double> pose = aligner.compose( sm_inst_.pose(), dpose_ );
00045   sm_inst_.set_pose( pose );
00046 
00047   if ( max_dparams_.size() != 0 )
00048   {
00049     // Generate a random parameter offset
00050     for ( unsigned i=0; i<n_params_; ++i )
00051       dparams_[i] = random_value(rand_,max_dparams_[i],rel_gauss_);
00052 
00053     vnl_vector<double> p = gt_inst_.params() + dparams_;
00054 
00055     // If any of the max_dp are zero, then assume
00056     // displacement is from zero
00057     for ( unsigned i=0; i<n_params_; ++i )
00058       if ( max_dparams_[i]==0 )
00059         p[i] = 0.0;
00060 
00061     sm_inst_.set_params(p);
00062   }
00063   else
00064   {
00065     // a fair perturbation should be from the mean shape
00066     if ( n_params_ > 0 )
00067     {
00068       dparams_ = vnl_vector<double>(n_params_,0);
00069       sm_inst_.set_params(dparams_);
00070     }
00071   }
00072 
00073   inv_dpose_ = aligner.inverse( sm_inst_.pose() );
00074   inv_dpose_ = aligner.compose( inv_dpose_, gt_inst_.pose() );
00075 
00076   for ( unsigned i=0; i<n_pose_; ++i )
00077   {
00078     all_[i]       = sm_inst_.pose()[i];
00079     dall_[i]      = dpose_[i];
00080     inv_dall_[i]  = inv_dpose_[i];
00081   }
00082 
00083   for ( unsigned i=0; i<n_params_; ++i )
00084   {
00085     all_[n_pose_+i]       = sm_inst_.params()[i];
00086     dall_[n_pose_+i]      = dparams_[i];
00087     inv_dall_[n_pose_+i]  = gt_inst_.params()[i] - sm_inst_.params()[i];
00088   }
00089 }
00090 
00091 void msm_shape_perturber::set_max_dp( const vnl_vector<double>& max_dpose, const vnl_vector<double>& max_dparams )
00092 {
00093   max_dpose_   = max_dpose;
00094   max_dparams_ = max_dparams;
00095 }
00096 
00097 void msm_shape_perturber::set_seed( vcl_size_t s )
00098 {
00099   rand_.reseed( s );
00100 }
00101 
00102 void msm_shape_perturber::set_rel_gauss( double val )
00103 {
00104   rel_gauss_ = val;
00105 }
00106 
00107 const vnl_vector<double>& msm_shape_perturber::params() const
00108 {
00109   return sm_inst_.params();
00110 }
00111 
00112 const vnl_vector<double>& msm_shape_perturber::pose() const
00113 {
00114   return sm_inst_.pose();
00115 }
00116 
00117 const vnl_vector<double>& msm_shape_perturber::gt_params() const
00118 {
00119   return gt_inst_.params();
00120 }
00121 
00122 const vnl_vector<double>& msm_shape_perturber::gt_pose() const
00123 {
00124   return gt_inst_.pose();
00125 }
00126 
00127 const vnl_vector<double>& msm_shape_perturber::all() const
00128 {
00129   return all_;
00130 }
00131 
00132 const vnl_vector<double>& msm_shape_perturber::inv_d_params() const
00133 {
00134   return inv_dparams_;
00135 }
00136 
00137 const vnl_vector<double>& msm_shape_perturber::inv_d_pose() const
00138 {
00139   return inv_dpose_;
00140 }
00141 
00142 const vnl_vector<double>& msm_shape_perturber::inv_d_all() const
00143 {
00144   return inv_dall_;
00145 }
00146 
00147 const msm_points& msm_shape_perturber::points() const
00148 {
00149   return const_cast<msm_shape_instance&>(sm_inst_).points();
00150 }
00151 
00152 
00153 double msm_shape_perturber::trunc_normal_sample( vnl_random& rand1,
00154                                                  double sd, double max_d )
00155 {
00156   double s=max_d+1;
00157   while ( s < -max_d || s > max_d )
00158     s = sd * rand1.normal64();
00159   return s;
00160 }
00161 
00162 double msm_shape_perturber::random_value( vnl_random& rand,
00163                                           double max_v, double rel_gauss_sd )
00164 {
00165   if ( max_v == 0 )
00166     return 0.0;
00167   else if ( rel_gauss_sd == 0.0 )
00168     return max_v * rand.drand64(-1,1);
00169   else
00170     return max_v * trunc_normal_sample( rand, rel_gauss_sd, 1 );
00171 }
00172