contrib/mul/mcal/mcal_general_ca.cxx
Go to the documentation of this file.
00001 #include "mcal_general_ca.h"
00002 //:
00003 // \file
00004 // \author Tim Cootes
00005 // \brief Class to perform general Component Analysis
00006 
00007 #include <vcl_cstdlib.h>
00008 #include <vcl_string.h>
00009 #include <vcl_sstream.h>
00010 
00011 #include <vsl/vsl_indent.h>
00012 #include <mbl/mbl_matxvec.h>
00013 #include <vcl_cmath.h>
00014 #include <vcl_vector.h>
00015 #include <vsl/vsl_binary_io.h>
00016 #include <mbl/mbl_parse_block.h>
00017 #include <mbl/mbl_read_props.h>
00018 #include <vul/vul_string.h>
00019 #include <mbl/mbl_exception.h>
00020 #include <vnl/algo/vnl_brent_minimizer.h>
00021 
00022 //=======================================================================
00023 // Constructors
00024 //=======================================================================
00025 
00026 mcal_general_ca::mcal_general_ca()
00027 {
00028   set_defaults();
00029 }
00030 
00031 //: Initialise, taking clones of supplied objects
00032 void mcal_general_ca::set(const mcal_component_analyzer& initial_ca,
00033                           const mcal_single_basis_cost& basis_cost)
00034 {
00035   initial_ca_ = initial_ca;
00036   basis_cost_ = basis_cost;
00037 }
00038 
00039 void mcal_general_ca::set_defaults()
00040 {
00041   max_passes_ = 50;
00042   move_thresh_=1e-4;
00043 }
00044 
00045 //=======================================================================
00046 // Destructor
00047 //=======================================================================
00048 
00049 mcal_general_ca::~mcal_general_ca()
00050 {
00051 }
00052 
00053 class mcal_pair_cost1 : public vnl_cost_function
00054 {
00055  private:
00056   const vnl_vector<double>& proj1_;
00057   const vnl_vector<double>& proj2_;
00058   const vnl_vector<double>& mode1_;
00059   const vnl_vector<double>& mode2_;
00060   mcal_single_basis_cost& cost_;
00061   vnl_vector<double> p1,p2,m1,m2;
00062  public:
00063   mcal_pair_cost1(const vnl_vector<double>& proj1,
00064                   const vnl_vector<double>& proj2,
00065                   const vnl_vector<double>& mode1,
00066                   const vnl_vector<double>& mode2,
00067                   mcal_single_basis_cost& cost)
00068   : vnl_cost_function(1), proj1_(proj1),proj2_(proj2),mode1_(mode1),mode2_(mode2),cost_(cost) {}
00069 
00070   double f(const vnl_vector<double>& x);
00071 };
00072 
00073 double mcal_pair_cost1::f(const vnl_vector<double>& x)
00074 {
00075   double sinA = vcl_sin(x[0]);
00076   double cosA = vcl_cos(x[0]);
00077 
00078   // Rotate axes by A=x[0]
00079   m1 = cosA*mode1_ + sinA*mode2_;
00080   m2 = cosA*mode2_ - sinA*mode1_;
00081 
00082   // Rotate projections equivalently
00083   p1 = cosA*proj1_ + sinA*proj2_;
00084   p2 = cosA*proj2_ - sinA*proj1_;
00085 
00086   double sum = cost_.cost(m1,p1) + cost_.cost(m2,p2);
00087   return sum;
00088 }
00089 
00090 //: Cost, assuming it can be evaluated from variance of projection
00091 class mcal_pair_cost2 : public vnl_cost_function
00092 {
00093  private:
00094   vnl_matrix<double> S_;
00095   const vnl_vector<double>& mode1_;
00096   const vnl_vector<double>& mode2_;
00097   mcal_single_basis_cost& cost_;
00098   vnl_vector<double> m1,m2;
00099  public:
00100   mcal_pair_cost2(const vnl_matrix<double>& S,
00101                   const vnl_vector<double>& mode1,
00102                   const vnl_vector<double>& mode2,
00103                   mcal_single_basis_cost& cost)
00104   : vnl_cost_function(1), S_(S),
00105     mode1_(mode1),mode2_(mode2),cost_(cost) {}
00106 
00107   mcal_pair_cost2(const vnl_vector<double>& proj1,
00108                   const vnl_vector<double>& proj2,
00109                   const vnl_vector<double>& mode1,
00110                   const vnl_vector<double>& mode2,
00111                   mcal_single_basis_cost& cost);
00112 
00113   double f(const vnl_vector<double>& x);
00114 
00115   void covar(const vnl_vector<double>& p1,
00116              const vnl_vector<double>& p2,
00117              vnl_matrix<double>& S);
00118 };
00119 
00120 mcal_pair_cost2::mcal_pair_cost2(const vnl_vector<double>& proj1,
00121                                  const vnl_vector<double>& proj2,
00122                                  const vnl_vector<double>& mode1,
00123                                  const vnl_vector<double>& mode2,
00124                                  mcal_single_basis_cost& cost)
00125   : vnl_cost_function(1),mode1_(mode1),mode2_(mode2),cost_(cost)
00126 {
00127   covar(proj1,proj2,S_);
00128 }
00129 
00130 
00131 double mcal_pair_cost2::f(const vnl_vector<double>& x)
00132 {
00133   double sinA = vcl_sin(x[0]);
00134   double cosA = vcl_cos(x[0]);
00135 
00136   // Rotate axes by A=x[0]
00137   m1 = cosA*mode1_ + sinA*mode2_;
00138   m2 = cosA*mode2_ - sinA*mode1_;
00139 
00140   // Rotate covariance equivalently
00141   vnl_matrix<double> R(2,2);
00142   R(0,0)=cosA;  R(0,1) = sinA;
00143   R(1,0)=-sinA; R(1,1) = cosA;
00144 
00145   vnl_matrix<double> SA = R*S_*R.transpose();
00146 
00147   double c1 = cost_.cost_from_variance(m1,SA(0,0));
00148   double c2 = cost_.cost_from_variance(m2,SA(1,1));
00149   return c1+c2;
00150 }
00151 
00152 void mcal_pair_cost2::covar(const vnl_vector<double>& p1,
00153                             const vnl_vector<double>& p2,
00154                             vnl_matrix<double>& S)
00155 {
00156   S.set_size(2,2);
00157   S(0,0) = dot_product(p1,p1)/p1.size();
00158   S(1,1) = dot_product(p2,p2)/p1.size();
00159   S(0,1) = dot_product(p1,p2)/p1.size();
00160   S(1,0) = S(0,1);
00161 }
00162 
00163 
00164 //: Optimise the mode vectors so as to minimise the cost function
00165 double mcal_general_ca::optimise_mode_pair(vnl_vector<double>& proj1,
00166                                            vnl_vector<double>& proj2,
00167                                            vnl_vector<double>& mode1,
00168                                            vnl_vector<double>& mode2)
00169 {
00170   vnl_cost_function *cost_fn;
00171   if (basis_cost().can_use_variance())
00172   {
00173     // Use more efficient cost evaluation
00174     cost_fn = new mcal_pair_cost2(proj1,proj2,mode1,mode2,basis_cost());
00175   }
00176   else
00177   {
00178     // Use cost which explicitly rotates projection data
00179     cost_fn = new mcal_pair_cost1(proj1,proj2,mode1,mode2,basis_cost());
00180   }
00181 
00182   vnl_brent_minimizer brent1(*cost_fn);
00183 
00184   // Note that rotation should be in range [0,pi/2)
00185   // There is fourfold cyclic symmetry - cost(A)==cost(A+pi/2)
00186   // We could perform an initial exhaustive search, then use
00187   // A=minimize_given_bounds(a,b,c)
00188   double A = brent1.minimize(0.0);
00189 
00190   // Tidy up
00191   delete cost_fn;
00192 
00193   if (A==0.0) return 0.0;
00194 
00195   // Apply rotation
00196   double sinA = vcl_sin(A);
00197   double cosA = vcl_cos(A);
00198 
00199   vnl_vector<double> m1=mode1,m2=mode2;
00200   vnl_vector<double> p1=proj1,p2=proj2;
00201 
00202   // Rotate axes by A=x[0]
00203   mode1 = cosA*m1 + sinA*m2;
00204   mode2 = cosA*m2 - sinA*m1;
00205 
00206   // Rotate projections equivalently
00207   proj1 = cosA*p1 + sinA*p2;
00208   proj2 = cosA*p2 - sinA*p1;
00209 
00210   return vcl_fabs(A);
00211 }
00212 
00213 //: Optimise the mode vectors so as to minimise the cost function
00214 double mcal_general_ca::optimise_one_pass(vcl_vector<vnl_vector<double> >& proj,
00215                                           vnl_matrix<double>& modes)
00216 {
00217   unsigned n_modes = modes.cols();
00218   double move_sum=0.0;
00219   for (unsigned i=1;i<n_modes;++i)
00220   {
00221     vnl_vector<double> mode1 = modes.get_column(i);
00222     for (unsigned j=0;j<i;++j)
00223     {
00224       vnl_vector<double> mode2 = modes.get_column(j);
00225       move_sum += optimise_mode_pair(proj[i],proj[j],mode1,mode2);
00226       modes.set_column(j,mode2);
00227     }
00228     modes.set_column(i,mode1);
00229   }
00230   return move_sum;
00231 }
00232 
00233 //: Compute projections onto each mode
00234 //  proj[j][i] is the projection of the i-th data sample onto the j-th mode
00235 void mcal_general_ca::compute_projections(mbl_data_wrapper<vnl_vector<double> >& data,
00236                                           const vnl_vector<double>& mean,
00237                                           vnl_matrix<double>& modes,
00238                                           vcl_vector<vnl_vector<double> >& proj)
00239 {
00240   // Compute projection of data onto each mode
00241   unsigned n_modes = modes.cols();
00242   unsigned n_egs   = data.size();
00243   proj.resize(n_modes);
00244   for (unsigned j=0;j<n_modes;++j) { proj[j].set_size(n_egs); }
00245   vnl_vector<double> b(n_modes);
00246   vnl_vector<double> dx;
00247   data.reset();
00248   for (unsigned i=0;i<n_egs;++i,data.next())
00249   {
00250     dx=data.current()-mean;
00251     mbl_matxvec_prod_vm(dx,modes,b);
00252     for (unsigned j=0;j<n_modes;++j) proj[j][i]=b[j];
00253   }
00254 }
00255 
00256 //: Optimise the mode vectors so as to minimise the cost function
00257 void mcal_general_ca::optimise_about_mean(mbl_data_wrapper<vnl_vector<double> >& data,
00258                                           const vnl_vector<double>& mean,
00259                                           vnl_matrix<double>& modes,
00260                                           vnl_vector<double>& mode_var)
00261 {
00262   // Compute projection of data onto each mode
00263   unsigned n_modes = mode_var.size();
00264   unsigned n_egs   = data.size();
00265   vcl_vector<vnl_vector<double> > proj(n_modes);
00266   compute_projections(data,mean,modes,proj);
00267 
00268   // Perform multiple passes
00269   for (unsigned i=0;i<max_passes_;++i)
00270   {
00271     if (optimise_one_pass(proj,modes)<move_thresh_) break;
00272   }
00273 
00274   // Compute the variances on each mode
00275   compute_projections(data,mean,modes,proj);
00276   mode_var.set_size(n_modes);
00277   for (unsigned j=0;j<n_modes;++j)
00278     mode_var[j]=proj[j].squared_magnitude()/n_egs;
00279 }
00280 
00281 //: Compute modes of the supplied data relative to the supplied mean
00282 //  Model is x = mean + modes*b,  where b is a vector of weights on each mode.
00283 //  mode_var[i] gives the variance of the data projected onto that mode.
00284 void mcal_general_ca::build_about_mean(mbl_data_wrapper<vnl_vector<double> >& data,
00285                                        const vnl_vector<double>& mean,
00286                                        vnl_matrix<double>& modes,
00287                                        vnl_vector<double>& mode_var)
00288 {
00289   if (data.size()==0)
00290   {
00291     vcl_cerr<<"mcal_general_ca::build_about_mean() No samples supplied.\n";
00292     vcl_abort();
00293   }
00294 
00295   data.reset();
00296 
00297   if (data.current().size()==0)
00298   {
00299     vcl_cerr<<"mcal_general_ca::build_about_mean()\n"
00300             <<"Warning: Samples claim to have zero dimensions.\n"
00301             <<"Constructing empty model.\n";
00302 
00303     modes.set_size(0,0);
00304     mode_var.set_size(0);
00305     return;
00306   }
00307 
00308   // Compute initial approximation
00309   initial_ca().build_about_mean(data,mean,modes,mode_var);
00310 
00311   // Now perform optimisation
00312   optimise_about_mean(data,mean,modes,mode_var);
00313 }
00314 
00315 
00316 //=======================================================================
00317 // Method: is_a
00318 //=======================================================================
00319 
00320 vcl_string  mcal_general_ca::is_a() const
00321 {
00322   return vcl_string("mcal_general_ca");
00323 }
00324 
00325 //=======================================================================
00326 // Method: version_no
00327 //=======================================================================
00328 
00329 short mcal_general_ca::version_no() const
00330 {
00331   return 1;
00332 }
00333 
00334 //=======================================================================
00335 // Method: clone
00336 //=======================================================================
00337 
00338 mcal_component_analyzer* mcal_general_ca::clone() const
00339 {
00340   return new mcal_general_ca(*this);
00341 }
00342 
00343 //=======================================================================
00344 // Method: print
00345 //=======================================================================
00346 
00347 void mcal_general_ca::print_summary(vcl_ostream& os) const
00348 {
00349   vsl_indent_inc(os);
00350   os<<"{\n"
00351     <<vsl_indent()<<"initial_ca: "<<initial_ca_<<'\n'
00352     <<vsl_indent()<<"basis_cost: "<<basis_cost_<<'\n'
00353     <<vsl_indent()<<"} ";
00354   vsl_indent_dec(os);
00355 }
00356 
00357 //=======================================================================
00358 // Method: save
00359 //=======================================================================
00360 
00361 void mcal_general_ca::b_write(vsl_b_ostream& bfs) const
00362 {
00363   vsl_b_write(bfs,version_no());
00364   vsl_b_write(bfs,initial_ca_);
00365   vsl_b_write(bfs,basis_cost_);
00366   vsl_b_write(bfs,max_passes_);
00367   vsl_b_write(bfs,move_thresh_);
00368 }
00369 
00370 //=======================================================================
00371 // Method: load
00372 //=======================================================================
00373 
00374 void mcal_general_ca::b_read(vsl_b_istream& bfs)
00375 {
00376   short version;
00377   vsl_b_read(bfs,version);
00378   switch (version)
00379   {
00380     case 1:
00381       vsl_b_read(bfs,initial_ca_);
00382       vsl_b_read(bfs,basis_cost_);
00383       vsl_b_read(bfs,max_passes_);
00384       vsl_b_read(bfs,move_thresh_);
00385       break;
00386     default:
00387       vcl_cerr << "mcal_general_ca::b_read()\n"
00388                << "Unexpected version number " << version << vcl_endl;
00389       vcl_abort();
00390   }
00391 }
00392 
00393 //=======================================================================
00394 //: Read initialisation settings from a stream.
00395 // Parameters:
00396 // \verbatim
00397 // {
00398 //   initial_ca: mcal_pca { ... }
00399 //   basis_cost: mcal_sparse_basis_cost { alpha: 0.1 }
00400 // }
00401 // \endverbatim
00402 // \throw mbl_exception_parse_error if the parse fails.
00403 void mcal_general_ca::config_from_stream(vcl_istream & is)
00404 {
00405   vcl_string s = mbl_parse_block(is);
00406 
00407   vcl_istringstream ss(s);
00408   mbl_read_props_type props = mbl_read_props_ws(ss);
00409 
00410   set_defaults();
00411 
00412   if (props.find("initial_ca")!=props.end())
00413   {
00414     vcl_istringstream ss(props["initial_ca"]);
00415     vcl_auto_ptr<mcal_component_analyzer> ca;
00416     ca=mcal_component_analyzer::create_from_stream(ss);
00417     initial_ca_ = *ca;
00418 
00419     props.erase("initial_ca");
00420   }
00421 
00422   if (props.find("basis_cost")!=props.end())
00423   {
00424     vcl_istringstream ss(props["basis_cost"]);
00425     vcl_auto_ptr<mcal_single_basis_cost> bc;
00426     bc=mcal_single_basis_cost::create_from_stream(ss);
00427     basis_cost_ = *bc;
00428 
00429     props.erase("basis_cost");
00430   }
00431 
00432   if (props.find("max_passes")!=props.end())
00433   {
00434     max_passes_ = vul_string_atoi(props["max_passes"]);
00435     props.erase("max_passes");
00436   }
00437   if (props.find("move_thresh")!=props.end())
00438   {
00439     move_thresh_ = vul_string_atoi(props["move_thresh"]);
00440     props.erase("move_thresh");
00441   }
00442 
00443 
00444   try
00445   {
00446     mbl_read_props_look_for_unused_props(
00447           "mcal_general_ca::config_from_stream", props);
00448   }
00449   catch(mbl_exception_unused_props &e)
00450   {
00451     throw mbl_exception_parse_error(e.what());
00452   }
00453 }