contrib/mul/mfpf/mfpf_pose_predictor_builder.cxx
Go to the documentation of this file.
00001 #include "mfpf_pose_predictor_builder.h"
00002 //:
00003 // \file
00004 // \brief Trains regressor in an mfpf_pose_predictor
00005 // \author Tim Cootes
00006 
00007 #include <vsl/vsl_binary_loader.h>
00008 #include <vcl_cmath.h>
00009 #include <vcl_algorithm.h>
00010 #include <vcl_cassert.h>
00011 #include <vnl/algo/vnl_svd.h>
00012 
00013 #include <vsl/vsl_indent.h>
00014 
00015 //=======================================================================
00016 // Dflt ctor
00017 //=======================================================================
00018 
00019 mfpf_pose_predictor_builder::mfpf_pose_predictor_builder()
00020 {
00021   set_defaults();
00022 }
00023 
00024 //: Define default values
00025 void mfpf_pose_predictor_builder::set_defaults()
00026 {
00027   n_per_eg_=25;
00028   rand_.reseed(57392);
00029 }
00030 
00031 //: Define number of samples per training image
00032 void mfpf_pose_predictor_builder::set_n_per_eg(unsigned n)
00033 {
00034   n_per_eg_=n;
00035 }
00036 
00037 
00038 //=======================================================================
00039 // Destructor
00040 //=======================================================================
00041 
00042 mfpf_pose_predictor_builder::~mfpf_pose_predictor_builder()
00043 {
00044 }
00045 
00046 //: Define sampling region and method
00047 //  Supplied predictor is partially initialised
00048 void mfpf_pose_predictor_builder::set_sampling(const mfpf_pose_predictor& pp)
00049 {
00050   sampler_=pp;
00051 }
00052 
00053 //: Initialise building
00054 // Must be called before any calls to add_example(...)
00055 void mfpf_pose_predictor_builder::clear(unsigned n_egs)
00056 {
00057   // Assume one plane image at first
00058   samples_.set_size(n_egs*n_per_eg_,sampler_.n_pixels()+1);
00059 
00060   unsigned nd=2;
00061   switch (sampler_.pose_type())
00062   {
00063     case translation: nd=2; break;
00064     case rigid:       nd=3; break;
00065     case zoom:        nd=3; break;
00066     case similarity:  nd=4; break;
00067     default: assert(!"Unknown pose_type"); break;
00068   }
00069 
00070   poses_.set_size(n_egs*n_per_eg_,nd);
00071 
00072   ci_ = 0;
00073 }
00074 
00075 //: Add one example to the model
00076 void mfpf_pose_predictor_builder::add_example(
00077                           const vimt_image_2d_of<float>& image,
00078                           const mfpf_pose& pose0)
00079 {
00080   double max_disp = 0.5*sampler_.radius()*sampler_.step_size();
00081   vnl_vector<double> vec;
00082   mfpf_pose dpose;
00083   for (unsigned i=0;i<n_per_eg_;++i,++ci_)
00084   {
00085   // Initial version hard coded for translation
00086 
00087     double dx = 0;
00088     double dy = 0;
00089     double s=0.0;
00090     double A=0.0;
00091     if (i>0)
00092     {
00093       dx=rand_.drand64(-max_disp,max_disp);
00094       dy=rand_.drand64(-max_disp,max_disp);
00095     }
00096     poses_(ci_,0)=dx;
00097     poses_(ci_,1)=dy;
00098 
00099     double max_dA_ = 0.25;  // radians
00100     double max_ds = vcl_log(1.2);
00101 
00102     switch (sampler_.pose_type())
00103     {
00104       case translation:
00105         break;
00106       case rigid:
00107         if (i>0) A = rand_.drand64(-max_dA_,max_dA_);
00108         poses_(ci_,2)=A;
00109         break;
00110       case zoom:
00111         if (i>0) s = rand_.drand64(-max_ds,max_ds);
00112         poses_(ci_,2)=s;
00113         break;
00114       case similarity:
00115         if (i>0)
00116         {
00117           A = rand_.drand64(-max_dA_,max_dA_);
00118           s = rand_.drand64(-max_ds,max_ds);
00119         }
00120         poses_(ci_,2)=s;
00121         poses_(ci_,3)=A;
00122         break;
00123       default:
00124         assert(!"Unknown pose_type");
00125         break;
00126     }
00127 
00128     dpose=mfpf_pose(dx,dy,vcl_exp(s)*vcl_cos(A),
00129                           vcl_exp(s)*vcl_sin(A));
00130 
00131     mfpf_pose pose = pose0*dpose;
00132     sampler_.get_sample_vector(image,pose.p(),pose.u(),vec);
00133     assert(vec.size()+1==samples_.cols());
00134     samples_(ci_,0)=1;
00135     for (unsigned j=0;j<vec.size();++j) samples_(ci_,1+j)=vec[j];
00136   }
00137 }
00138 
00139 //: Build object from the data supplied in add_example()
00140 void mfpf_pose_predictor_builder::build(mfpf_pose_predictor& p)
00141 {
00142   unsigned nv = poses_.cols();
00143   vnl_svd<double> svd(samples_);
00144 
00145   // Need to solve samples_*R=poses_ for R
00146   //                 (ns * np)(np*nv) = (ns * nv)
00147 
00148   unsigned n_samples = samples_.rows();
00149   if (n_samples>3*samples_.cols())
00150   {
00151     // Lots more samples than pixels
00152     vnl_matrix<double> R(nv,samples_.cols()-1);
00153     vnl_vector<double> r0(nv);
00154     for (unsigned i=0;i<nv;++i)
00155     {
00156       // Inefficient:
00157       vnl_vector<double> r = svd.solve(poses_.get_column(i));
00158       r0[i]=r[0];
00159       for (unsigned j=1;j<r.size();++j) R(i,j-1)=r[j];
00160     }
00161     p=sampler_;  // Define sampling
00162     p.set_predictor(R,r0);  // Define learned predictor
00163   }
00164   else
00165   {
00166     // May not be enough samples to train properly, so
00167     // use a reduced dimensional subspace
00168     unsigned rank = vcl_max(nv,unsigned(samples_.cols()/3));
00169     vnl_matrix<double> R1 = svd.pinverse(rank)*poses_;
00170     unsigned np=R1.rows();
00171     p=sampler_;  // Define sampling
00172     // Define learned predictor
00173     p.set_predictor(R1.extract(np-1,nv,1,0).transpose(),R1.get_row(0));
00174   }
00175 }
00176 
00177 
00178 //=======================================================================
00179 // Method: is_a
00180 //=======================================================================
00181 
00182 vcl_string mfpf_pose_predictor_builder::is_a() const
00183 {
00184   return vcl_string("mfpf_pose_predictor_builder");
00185 }
00186 
00187 //: Create a copy on the heap and return base class pointer
00188 mfpf_pose_predictor_builder* mfpf_pose_predictor_builder::clone() const
00189 {
00190   return new mfpf_pose_predictor_builder(*this);
00191 }
00192 
00193 //=======================================================================
00194 // Method: print
00195 //=======================================================================
00196 
00197 void mfpf_pose_predictor_builder::print_summary(vcl_ostream& os) const
00198 {
00199   os << "{  sampler: "<<sampler_ << '\n'
00200      << vsl_indent() << '}';
00201 }
00202 
00203 short mfpf_pose_predictor_builder::version_no() const
00204 {
00205   return 1;
00206 }
00207 
00208 
00209 void mfpf_pose_predictor_builder::b_write(vsl_b_ostream& bfs) const
00210 {
00211   vsl_b_write(bfs,version_no());
00212   vsl_b_write(bfs,sampler_);
00213   vsl_b_write(bfs,n_per_eg_);
00214 }
00215 
00216 //=======================================================================
00217 // Method: load
00218 //=======================================================================
00219 
00220 void mfpf_pose_predictor_builder::b_read(vsl_b_istream& bfs)
00221 {
00222   if (!bfs) return;
00223   short version;
00224   vsl_b_read(bfs,version);
00225   switch (version)
00226   {
00227     case (1):
00228       vsl_b_read(bfs,sampler_);
00229       vsl_b_read(bfs,n_per_eg_);
00230       break;
00231     default:
00232       vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&)\n"
00233                << "           Unknown version number "<< version << vcl_endl;
00234       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00235       return;
00236   }
00237 }
00238 
00239 //: Test equality
00240 bool mfpf_pose_predictor_builder::operator==(const mfpf_pose_predictor_builder& nc) const
00241 {
00242   if (!(sampler_==nc.sampler_)) return false;
00243   if (n_per_eg_!=nc.n_per_eg_) return false;
00244   return true;
00245 }
00246 
00247 //=======================================================================
00248 // Associated function: operator<<
00249 //=======================================================================
00250 
00251 void vsl_b_write(vsl_b_ostream& bfs, const mfpf_pose_predictor_builder& b)
00252 {
00253     b.b_write(bfs);
00254 }
00255 
00256 //=======================================================================
00257 // Associated function: operator>>
00258 //=======================================================================
00259 
00260 void vsl_b_read(vsl_b_istream& bfs, mfpf_pose_predictor_builder& b)
00261 {
00262     b.b_read(bfs);
00263 }
00264 
00265 //=======================================================================
00266 // Associated function: operator<<
00267 //=======================================================================
00268 
00269 vcl_ostream& operator<<(vcl_ostream& os,const mfpf_pose_predictor_builder& b)
00270 {
00271   os << b.is_a() << ": ";
00272   vsl_indent_inc(os);
00273   b.print_summary(os);
00274   vsl_indent_dec(os);
00275   return os;
00276 }
00277