00001 #include "mcal_general_ca.h"
00002
00003
00004
00005
00006
00007 #include <vcl_cstdlib.h>
00008 #include <vcl_string.h>
00009 #include <vcl_sstream.h>
00010
00011 #include <vsl/vsl_indent.h>
00012 #include <mbl/mbl_matxvec.h>
00013 #include <vcl_cmath.h>
00014 #include <vcl_vector.h>
00015 #include <vsl/vsl_binary_io.h>
00016 #include <mbl/mbl_parse_block.h>
00017 #include <mbl/mbl_read_props.h>
00018 #include <vul/vul_string.h>
00019 #include <mbl/mbl_exception.h>
00020 #include <vnl/algo/vnl_brent_minimizer.h>
00021
00022
00023
00024
00025
00026 mcal_general_ca::mcal_general_ca()
00027 {
00028 set_defaults();
00029 }
00030
00031
00032 void mcal_general_ca::set(const mcal_component_analyzer& initial_ca,
00033 const mcal_single_basis_cost& basis_cost)
00034 {
00035 initial_ca_ = initial_ca;
00036 basis_cost_ = basis_cost;
00037 }
00038
00039 void mcal_general_ca::set_defaults()
00040 {
00041 max_passes_ = 50;
00042 move_thresh_=1e-4;
00043 }
00044
00045
00046
00047
00048
00049 mcal_general_ca::~mcal_general_ca()
00050 {
00051 }
00052
00053 class mcal_pair_cost1 : public vnl_cost_function
00054 {
00055 private:
00056 const vnl_vector<double>& proj1_;
00057 const vnl_vector<double>& proj2_;
00058 const vnl_vector<double>& mode1_;
00059 const vnl_vector<double>& mode2_;
00060 mcal_single_basis_cost& cost_;
00061 vnl_vector<double> p1,p2,m1,m2;
00062 public:
00063 mcal_pair_cost1(const vnl_vector<double>& proj1,
00064 const vnl_vector<double>& proj2,
00065 const vnl_vector<double>& mode1,
00066 const vnl_vector<double>& mode2,
00067 mcal_single_basis_cost& cost)
00068 : vnl_cost_function(1), proj1_(proj1),proj2_(proj2),mode1_(mode1),mode2_(mode2),cost_(cost) {}
00069
00070 double f(const vnl_vector<double>& x);
00071 };
00072
00073 double mcal_pair_cost1::f(const vnl_vector<double>& x)
00074 {
00075 double sinA = vcl_sin(x[0]);
00076 double cosA = vcl_cos(x[0]);
00077
00078
00079 m1 = cosA*mode1_ + sinA*mode2_;
00080 m2 = cosA*mode2_ - sinA*mode1_;
00081
00082
00083 p1 = cosA*proj1_ + sinA*proj2_;
00084 p2 = cosA*proj2_ - sinA*proj1_;
00085
00086 double sum = cost_.cost(m1,p1) + cost_.cost(m2,p2);
00087 return sum;
00088 }
00089
00090
00091 class mcal_pair_cost2 : public vnl_cost_function
00092 {
00093 private:
00094 vnl_matrix<double> S_;
00095 const vnl_vector<double>& mode1_;
00096 const vnl_vector<double>& mode2_;
00097 mcal_single_basis_cost& cost_;
00098 vnl_vector<double> m1,m2;
00099 public:
00100 mcal_pair_cost2(const vnl_matrix<double>& S,
00101 const vnl_vector<double>& mode1,
00102 const vnl_vector<double>& mode2,
00103 mcal_single_basis_cost& cost)
00104 : vnl_cost_function(1), S_(S),
00105 mode1_(mode1),mode2_(mode2),cost_(cost) {}
00106
00107 mcal_pair_cost2(const vnl_vector<double>& proj1,
00108 const vnl_vector<double>& proj2,
00109 const vnl_vector<double>& mode1,
00110 const vnl_vector<double>& mode2,
00111 mcal_single_basis_cost& cost);
00112
00113 double f(const vnl_vector<double>& x);
00114
00115 void covar(const vnl_vector<double>& p1,
00116 const vnl_vector<double>& p2,
00117 vnl_matrix<double>& S);
00118 };
00119
00120 mcal_pair_cost2::mcal_pair_cost2(const vnl_vector<double>& proj1,
00121 const vnl_vector<double>& proj2,
00122 const vnl_vector<double>& mode1,
00123 const vnl_vector<double>& mode2,
00124 mcal_single_basis_cost& cost)
00125 : vnl_cost_function(1),mode1_(mode1),mode2_(mode2),cost_(cost)
00126 {
00127 covar(proj1,proj2,S_);
00128 }
00129
00130
00131 double mcal_pair_cost2::f(const vnl_vector<double>& x)
00132 {
00133 double sinA = vcl_sin(x[0]);
00134 double cosA = vcl_cos(x[0]);
00135
00136
00137 m1 = cosA*mode1_ + sinA*mode2_;
00138 m2 = cosA*mode2_ - sinA*mode1_;
00139
00140
00141 vnl_matrix<double> R(2,2);
00142 R(0,0)=cosA; R(0,1) = sinA;
00143 R(1,0)=-sinA; R(1,1) = cosA;
00144
00145 vnl_matrix<double> SA = R*S_*R.transpose();
00146
00147 double c1 = cost_.cost_from_variance(m1,SA(0,0));
00148 double c2 = cost_.cost_from_variance(m2,SA(1,1));
00149 return c1+c2;
00150 }
00151
00152 void mcal_pair_cost2::covar(const vnl_vector<double>& p1,
00153 const vnl_vector<double>& p2,
00154 vnl_matrix<double>& S)
00155 {
00156 S.set_size(2,2);
00157 S(0,0) = dot_product(p1,p1)/p1.size();
00158 S(1,1) = dot_product(p2,p2)/p1.size();
00159 S(0,1) = dot_product(p1,p2)/p1.size();
00160 S(1,0) = S(0,1);
00161 }
00162
00163
00164
00165 double mcal_general_ca::optimise_mode_pair(vnl_vector<double>& proj1,
00166 vnl_vector<double>& proj2,
00167 vnl_vector<double>& mode1,
00168 vnl_vector<double>& mode2)
00169 {
00170 vnl_cost_function *cost_fn;
00171 if (basis_cost().can_use_variance())
00172 {
00173
00174 cost_fn = new mcal_pair_cost2(proj1,proj2,mode1,mode2,basis_cost());
00175 }
00176 else
00177 {
00178
00179 cost_fn = new mcal_pair_cost1(proj1,proj2,mode1,mode2,basis_cost());
00180 }
00181
00182 vnl_brent_minimizer brent1(*cost_fn);
00183
00184
00185
00186
00187
00188 double A = brent1.minimize(0.0);
00189
00190
00191 delete cost_fn;
00192
00193 if (A==0.0) return 0.0;
00194
00195
00196 double sinA = vcl_sin(A);
00197 double cosA = vcl_cos(A);
00198
00199 vnl_vector<double> m1=mode1,m2=mode2;
00200 vnl_vector<double> p1=proj1,p2=proj2;
00201
00202
00203 mode1 = cosA*m1 + sinA*m2;
00204 mode2 = cosA*m2 - sinA*m1;
00205
00206
00207 proj1 = cosA*p1 + sinA*p2;
00208 proj2 = cosA*p2 - sinA*p1;
00209
00210 return vcl_fabs(A);
00211 }
00212
00213
00214 double mcal_general_ca::optimise_one_pass(vcl_vector<vnl_vector<double> >& proj,
00215 vnl_matrix<double>& modes)
00216 {
00217 unsigned n_modes = modes.cols();
00218 double move_sum=0.0;
00219 for (unsigned i=1;i<n_modes;++i)
00220 {
00221 vnl_vector<double> mode1 = modes.get_column(i);
00222 for (unsigned j=0;j<i;++j)
00223 {
00224 vnl_vector<double> mode2 = modes.get_column(j);
00225 move_sum += optimise_mode_pair(proj[i],proj[j],mode1,mode2);
00226 modes.set_column(j,mode2);
00227 }
00228 modes.set_column(i,mode1);
00229 }
00230 return move_sum;
00231 }
00232
00233
00234
00235 void mcal_general_ca::compute_projections(mbl_data_wrapper<vnl_vector<double> >& data,
00236 const vnl_vector<double>& mean,
00237 vnl_matrix<double>& modes,
00238 vcl_vector<vnl_vector<double> >& proj)
00239 {
00240
00241 unsigned n_modes = modes.cols();
00242 unsigned n_egs = data.size();
00243 proj.resize(n_modes);
00244 for (unsigned j=0;j<n_modes;++j) { proj[j].set_size(n_egs); }
00245 vnl_vector<double> b(n_modes);
00246 vnl_vector<double> dx;
00247 data.reset();
00248 for (unsigned i=0;i<n_egs;++i,data.next())
00249 {
00250 dx=data.current()-mean;
00251 mbl_matxvec_prod_vm(dx,modes,b);
00252 for (unsigned j=0;j<n_modes;++j) proj[j][i]=b[j];
00253 }
00254 }
00255
00256
00257 void mcal_general_ca::optimise_about_mean(mbl_data_wrapper<vnl_vector<double> >& data,
00258 const vnl_vector<double>& mean,
00259 vnl_matrix<double>& modes,
00260 vnl_vector<double>& mode_var)
00261 {
00262
00263 unsigned n_modes = mode_var.size();
00264 unsigned n_egs = data.size();
00265 vcl_vector<vnl_vector<double> > proj(n_modes);
00266 compute_projections(data,mean,modes,proj);
00267
00268
00269 for (unsigned i=0;i<max_passes_;++i)
00270 {
00271 if (optimise_one_pass(proj,modes)<move_thresh_) break;
00272 }
00273
00274
00275 compute_projections(data,mean,modes,proj);
00276 mode_var.set_size(n_modes);
00277 for (unsigned j=0;j<n_modes;++j)
00278 mode_var[j]=proj[j].squared_magnitude()/n_egs;
00279 }
00280
00281
00282
00283
00284 void mcal_general_ca::build_about_mean(mbl_data_wrapper<vnl_vector<double> >& data,
00285 const vnl_vector<double>& mean,
00286 vnl_matrix<double>& modes,
00287 vnl_vector<double>& mode_var)
00288 {
00289 if (data.size()==0)
00290 {
00291 vcl_cerr<<"mcal_general_ca::build_about_mean() No samples supplied.\n";
00292 vcl_abort();
00293 }
00294
00295 data.reset();
00296
00297 if (data.current().size()==0)
00298 {
00299 vcl_cerr<<"mcal_general_ca::build_about_mean()\n"
00300 <<"Warning: Samples claim to have zero dimensions.\n"
00301 <<"Constructing empty model.\n";
00302
00303 modes.set_size(0,0);
00304 mode_var.set_size(0);
00305 return;
00306 }
00307
00308
00309 initial_ca().build_about_mean(data,mean,modes,mode_var);
00310
00311
00312 optimise_about_mean(data,mean,modes,mode_var);
00313 }
00314
00315
00316
00317
00318
00319
00320 vcl_string mcal_general_ca::is_a() const
00321 {
00322 return vcl_string("mcal_general_ca");
00323 }
00324
00325
00326
00327
00328
00329 short mcal_general_ca::version_no() const
00330 {
00331 return 1;
00332 }
00333
00334
00335
00336
00337
00338 mcal_component_analyzer* mcal_general_ca::clone() const
00339 {
00340 return new mcal_general_ca(*this);
00341 }
00342
00343
00344
00345
00346
00347 void mcal_general_ca::print_summary(vcl_ostream& os) const
00348 {
00349 vsl_indent_inc(os);
00350 os<<"{\n"
00351 <<vsl_indent()<<"initial_ca: "<<initial_ca_<<'\n'
00352 <<vsl_indent()<<"basis_cost: "<<basis_cost_<<'\n'
00353 <<vsl_indent()<<"} ";
00354 vsl_indent_dec(os);
00355 }
00356
00357
00358
00359
00360
00361 void mcal_general_ca::b_write(vsl_b_ostream& bfs) const
00362 {
00363 vsl_b_write(bfs,version_no());
00364 vsl_b_write(bfs,initial_ca_);
00365 vsl_b_write(bfs,basis_cost_);
00366 vsl_b_write(bfs,max_passes_);
00367 vsl_b_write(bfs,move_thresh_);
00368 }
00369
00370
00371
00372
00373
00374 void mcal_general_ca::b_read(vsl_b_istream& bfs)
00375 {
00376 short version;
00377 vsl_b_read(bfs,version);
00378 switch (version)
00379 {
00380 case 1:
00381 vsl_b_read(bfs,initial_ca_);
00382 vsl_b_read(bfs,basis_cost_);
00383 vsl_b_read(bfs,max_passes_);
00384 vsl_b_read(bfs,move_thresh_);
00385 break;
00386 default:
00387 vcl_cerr << "mcal_general_ca::b_read()\n"
00388 << "Unexpected version number " << version << vcl_endl;
00389 vcl_abort();
00390 }
00391 }
00392
00393
00394
00395
00396
00397
00398
00399
00400
00401
00402
00403 void mcal_general_ca::config_from_stream(vcl_istream & is)
00404 {
00405 vcl_string s = mbl_parse_block(is);
00406
00407 vcl_istringstream ss(s);
00408 mbl_read_props_type props = mbl_read_props_ws(ss);
00409
00410 set_defaults();
00411
00412 if (props.find("initial_ca")!=props.end())
00413 {
00414 vcl_istringstream ss(props["initial_ca"]);
00415 vcl_auto_ptr<mcal_component_analyzer> ca;
00416 ca=mcal_component_analyzer::create_from_stream(ss);
00417 initial_ca_ = *ca;
00418
00419 props.erase("initial_ca");
00420 }
00421
00422 if (props.find("basis_cost")!=props.end())
00423 {
00424 vcl_istringstream ss(props["basis_cost"]);
00425 vcl_auto_ptr<mcal_single_basis_cost> bc;
00426 bc=mcal_single_basis_cost::create_from_stream(ss);
00427 basis_cost_ = *bc;
00428
00429 props.erase("basis_cost");
00430 }
00431
00432 if (props.find("max_passes")!=props.end())
00433 {
00434 max_passes_ = vul_string_atoi(props["max_passes"]);
00435 props.erase("max_passes");
00436 }
00437 if (props.find("move_thresh")!=props.end())
00438 {
00439 move_thresh_ = vul_string_atoi(props["move_thresh"]);
00440 props.erase("move_thresh");
00441 }
00442
00443
00444 try
00445 {
00446 mbl_read_props_look_for_unused_props(
00447 "mcal_general_ca::config_from_stream", props);
00448 }
00449 catch(mbl_exception_unused_props &e)
00450 {
00451 throw mbl_exception_parse_error(e.what());
00452 }
00453 }