contrib/mul/mmn/mmn_dp_solver.cxx
Go to the documentation of this file.
00001 #include "mmn_dp_solver.h"
00002 //:
00003 // \file
00004 // \brief Solve a restricted class of Markov problems (trees and tri-trees)
00005 // \author Tim Cootes
00006 
00007 #include <mmn/mmn_order_cost.h>
00008 #include <mmn/mmn_graph_rep1.h>
00009 #include <vcl_cassert.h>
00010 #include <vcl_cstdlib.h>
00011 #include <vcl_sstream.h>
00012 
00013 #include <mbl/mbl_parse_block.h>
00014 #include <mbl/mbl_read_props.h>
00015 
00016 //: Default constructor
00017 mmn_dp_solver::mmn_dp_solver()
00018 {
00019 }
00020 
00021 //: Input the arcs that define the graph
00022 void mmn_dp_solver::set_arcs(unsigned num_nodes,
00023                              const vcl_vector<mmn_arc>& arcs)
00024 {
00025   // Copy in arcs, and ensure ordering v1<v2
00026   vcl_vector<mmn_arc> ordered_arcs(arcs.size());
00027   for (unsigned i=0;i<arcs.size();++i)
00028   {
00029     if (arcs[i].v1<arcs[i].v2)
00030       ordered_arcs[i]= arcs[i];
00031     else
00032       ordered_arcs[i]= mmn_arc(arcs[i].v2,arcs[i].v1);
00033   }
00034 
00035   mmn_graph_rep1 graph;
00036   graph.build(num_nodes,ordered_arcs);
00037   vcl_vector<mmn_dependancy> deps;
00038   if (!graph.compute_dependancies(deps,0))
00039   {
00040     vcl_cerr<<"Graph cannot be decomposed - too complex.\n"
00041             <<"Arc list: ";
00042     for (unsigned i=0;i<arcs.size();++i) vcl_cout<<arcs[i];
00043     vcl_cerr<<'\n';
00044     vcl_abort();
00045   }
00046 
00047   set_dependancies(deps,num_nodes,graph.max_n_arcs());
00048 }
00049 
00050 
00051 //: Index of root node
00052 unsigned mmn_dp_solver::root() const
00053 {
00054   if (deps_.size()==0) return 0;
00055   return deps_[deps_.size()-1].v1;
00056 }
00057 
00058 //: Define dependencies
00059 void mmn_dp_solver::set_dependancies(const vcl_vector<mmn_dependancy>& deps,
00060                                      unsigned n_nodes, unsigned max_n_arcs)
00061 {
00062   deps_ = deps;
00063   nc_.resize(n_nodes);
00064   pc_.resize(max_n_arcs);
00065   index1_.resize(n_nodes);
00066   index2_.resize(n_nodes);
00067 }
00068 
00069 void mmn_dp_solver::process_dep1(const mmn_dependancy& dep)
00070 {
00071   // dep->v0 depends on dep->v1 through arc dep->arc1
00072   const vnl_vector<double>& nc0 = nc_[dep.v0];
00073   vnl_vector<double>& nc1 = nc_[dep.v1];
00074   vnl_matrix<double>& p = pc_[dep.arc1];
00075   vcl_vector<unsigned>& i0 = index1_[dep.v0];
00076 
00077   // Check sizes of matrices
00078   if (dep.v0<dep.v1)
00079   {
00080     assert(p.rows()==nc0.size());
00081     assert(p.cols()==nc1.size());
00082   }
00083   else
00084   {
00085     if (p.rows()!=nc1.size())
00086     {
00087       vcl_cerr<<"p.rows()="<<p.rows()<<"p.cols()="<<p.cols()
00088               <<" nc0.size()="<<nc0.size()
00089               <<" nc1.size()="<<nc1.size()<<vcl_endl
00090               <<"dep: "<<dep<<vcl_endl;
00091     }
00092     assert(p.rows()==nc1.size());
00093     assert(p.cols()==nc0.size());
00094   }
00095 
00096   // Set i0[i1] to the optimal choice of node v0 if v1 is i1
00097   i0.resize(nc1.size());
00098   for (unsigned j=0;j<nc1.size();++j)
00099   {
00100     double min_v;
00101     unsigned best_i=0;
00102     if (dep.v0<dep.v1)
00103     {
00104       min_v=nc0[0]+p(0,j);
00105       for (unsigned i=1;i<nc0.size();++i)
00106       {
00107         double v=nc0[i]+p(i,j);
00108         if (v<min_v) { min_v=v; best_i=i; }
00109       }
00110     }
00111     else
00112     {
00113       min_v=nc0[0]+p(j,0);
00114       for (unsigned i=1;i<nc0.size();++i)
00115       {
00116         double v=nc0[i]+p(j,i);
00117         if (v<min_v) { min_v=v; best_i=i; }
00118       }
00119     }
00120     i0[j]=best_i;
00121     nc1[j]+=min_v;  // Update costs for node v1
00122   }
00123 }
00124 
00125 void mmn_dp_solver::process_dep2(const mmn_dependancy& dep)
00126 {
00127   // n_dep==2
00128   // dep->v0 depends on dep->v1 and dep->v2
00129   // dep->v0 depends on dep->v1 through arc dep->arc1
00130   const vnl_vector<double>& nc0 = nc_[dep.v0];
00131   const vnl_vector<double>& nc1 = nc_[dep.v1];
00132   const vnl_vector<double>& nc2 = nc_[dep.v2];
00133   const vnl_matrix<double>& pa1 = pc_[dep.arc1];
00134   const vnl_matrix<double>& pa2 = pc_[dep.arc2];
00135   vnl_matrix<double>& pa12 = pc_[dep.arc12];
00136   vnl_matrix<int>& ind0 = index2_[dep.v0];
00137 
00138   if (pa12.size()==0)
00139   {
00140     if (dep.v1<dep.v2)
00141       pa12.set_size(nc1.size(),nc2.size());
00142     else
00143       pa12.set_size(nc2.size(),nc1.size());
00144     pa12.fill(0.0);
00145   }
00146 
00147   // i0[i1,i2] to the optimal choice of node v0 if v1 is i1, v2 is i2
00148   ind0.set_size(nc1.size(),nc2.size());
00149 
00150   for (unsigned i1=0;i1<nc1.size();++i1)
00151   {
00152     vnl_vector<double> sum0(nc0);
00153     if (dep.v0<dep.v1) sum0+=pa1.get_column(i1);
00154     else               sum0+=pa1.get_row(i1);
00155 
00156     for (unsigned i2=0;i2<nc2.size();++i2)
00157     {
00158       vnl_vector<double> sum(sum0);
00159       if (dep.v0<dep.v2) sum+=pa2.get_column(i2);
00160       else               sum+=pa2.get_row(i2);
00161 
00162       // sum[i] is the cost of choosing i, given (i1,i2)
00163       // Select minimum
00164       unsigned best_i=0;
00165       double min_v=sum[0];
00166       for (unsigned i=1;i<sum.size();++i)
00167         if (sum[i]<min_v) { min_v=sum[i]; best_i=i; }
00168 
00169       // Record position of minima
00170       ind0(i1,i2)=best_i;
00171       // Update pairwise cost for arc between v1 and v2
00172       if (dep.v1<dep.v2) { pa12(i1,i2)+=min_v; }
00173       else               { pa12(i2,i1)+=min_v; }
00174     }
00175   }
00176 }
00177 
00178 
00179 //: Compute optimal choice for dep.v0 given v1 and v2
00180 //  Includes cost depending on (v0,v1,v2) as well as pairwise and
00181 //  node costs.
00182 // tri_cost(i,j,k) is cost of associating smallest node index
00183 // with i, next with j and largest node index with k.
00184 void mmn_dp_solver::process_dep2t(const mmn_dependancy& dep,
00185                                   const vil_image_view<double>& tri_cost)
00186 {
00187   // n_dep==2
00188   // dep->v0 depends on dep->v1 and dep->v2
00189   // dep->v0 depends on dep->v1 through arc dep->arc1
00190   const vnl_vector<double>& nc0 = nc_[dep.v0];
00191   const vnl_vector<double>& nc1 = nc_[dep.v1];
00192   const vnl_vector<double>& nc2 = nc_[dep.v2];
00193   const vnl_matrix<double>& pa1 = pc_[dep.arc1];
00194   const vnl_matrix<double>& pa2 = pc_[dep.arc2];
00195   vnl_matrix<double>& pa12 = pc_[dep.arc12];
00196   vnl_matrix<int>& ind0 = index2_[dep.v0];
00197 
00198   // Create a re-ordered view of tri_cost, so we can use tc(i1,i2,i3)
00199   vil_image_view<double> tc=mmn_unorder_cost(tri_cost,
00200                                              dep.v0,dep.v1,dep.v2);
00201   vcl_ptrdiff_t tc_step0=tc.istep();
00202 
00203   if (pa12.size()==0)
00204   {
00205     if (dep.v1<dep.v2)
00206       pa12.set_size(nc1.size(),nc2.size());
00207     else
00208       pa12.set_size(nc2.size(),nc1.size());
00209     pa12.fill(0.0);
00210   }
00211 
00212   // i0[i1,i2] to the optimal choice of node v0 if v1 is i1, v2 is i2
00213   ind0.set_size(nc1.size(),nc2.size());
00214 
00215   for (unsigned i1=0;i1<nc1.size();++i1)
00216   {
00217     vnl_vector<double> sum0(nc0);
00218     if (dep.v0<dep.v1) sum0+=pa1.get_column(i1);
00219     else               sum0+=pa1.get_row(i1);
00220 
00221     for (unsigned i2=0;i2<nc2.size();++i2)
00222     {
00223       vnl_vector<double> sum(sum0);
00224       if (dep.v0<dep.v2) sum+=pa2.get_column(i2);
00225       else               sum+=pa2.get_row(i2);
00226 
00227       // sum[i] is the cost of choosing i, given (i1,i2)
00228       // Select minimum
00229       unsigned best_i=0;
00230       const double *tci=&tc(0,i1,i2);
00231       double min_v=sum[0]+tci[0];
00232       tci+=tc_step0; // move to element 1
00233       for (unsigned i=1;i<sum.size();++i,tci+=tc_step0)
00234       {
00235         sum[i]+=(*tci);
00236         if (sum[i]<min_v) { min_v=sum[i]; best_i=i; }
00237       }
00238 
00239       // Record position of minima
00240       ind0(i1,i2)=best_i;
00241       // Update pairwise cost for arc between v1 and v2
00242       if (dep.v1<dep.v2) { pa12(i1,i2)+=min_v; }
00243       else               { pa12(i2,i1)+=min_v; }
00244     }
00245   }
00246 }
00247 
00248 
00249 double mmn_dp_solver::solve(
00250                  const vcl_vector<vnl_vector<double> >& node_cost,
00251                  const vcl_vector<vnl_matrix<double> >& pair_cost,
00252                  vcl_vector<unsigned>& x)
00253 {
00254   nc_ = node_cost;
00255   for (unsigned i=0;i<pair_cost.size();++i) pc_[i]=pair_cost[i];
00256   for (unsigned i=pair_cost.size();i<pc_.size();++i) pc_[i].set_size(0,0);
00257 
00258   if (deps_.size()==0)
00259   {
00260     vcl_cerr<<"No dependencies.\n";
00261     return 999.99;
00262   }
00263 
00264   // Process dependencies in given order
00265   vcl_vector<mmn_dependancy>::const_iterator dep=deps_.begin();
00266   for (;dep!=deps_.end();dep++)
00267   {
00268     if (dep->n_dep==1) process_dep1(*dep);
00269     else               process_dep2(*dep);
00270   }
00271 
00272   const vnl_vector<double>& root_cost = nc_[root()];
00273   unsigned best_i=0;
00274   double min_v=root_cost[0];
00275   for (unsigned i=1;i<root_cost.size();++i)
00276     if (root_cost[i]<min_v) { min_v=root_cost[i]; best_i=i; }
00277 
00278   backtrace(best_i,x);
00279   return min_v;
00280 }
00281 
00282 double mmn_dp_solver::solve(
00283                  const vcl_vector<vnl_vector<double> >& node_cost,
00284                  const vcl_vector<vnl_matrix<double> >& pair_cost,
00285                  const vcl_vector<vil_image_view<double> >& tri_cost,
00286                  vcl_vector<unsigned>& x)
00287 {
00288   nc_ = node_cost;
00289   for (unsigned i=0;i<pair_cost.size();++i) pc_[i]=pair_cost[i];
00290   for (unsigned i=pair_cost.size();i<pc_.size();++i) pc_[i].set_size(0,0);
00291 
00292   if (deps_.size()==0)
00293   {
00294     vcl_cerr<<"No dependencies.\n";
00295     return 999.99;
00296   }
00297 
00298   // Process dependencies in given order
00299   vcl_vector<mmn_dependancy>::const_iterator dep=deps_.begin();
00300   for (;dep!=deps_.end();dep++)
00301   {
00302     if (dep->n_dep==1) process_dep1(*dep);
00303     else
00304     {
00305       if (dep->tri1==mmn_no_tri) process_dep2(*dep);
00306       else
00307       {
00308         // dep->v0 depends on arcs and a triplet relationship
00309         assert(dep->tri1 < tri_cost.size());
00310         process_dep2t(*dep,tri_cost[dep->tri1]);
00311       }
00312     }
00313   }
00314 
00315   const vnl_vector<double>& root_cost = nc_[root()];
00316   unsigned best_i=0;
00317   double min_v=root_cost[0];
00318   for (unsigned i=1;i<root_cost.size();++i)
00319     if (root_cost[i]<min_v) { min_v=root_cost[i]; best_i=i; }
00320 
00321   backtrace(best_i,x);
00322   return min_v;
00323 }
00324 
00325 
00326 //: Compute optimal values for x[i] given that root node is root_value
00327 //  Assumes that solve() has been already called.
00328 void mmn_dp_solver::backtrace(unsigned root_value,vcl_vector<unsigned>& x)
00329 {
00330   x.resize(nc_.size());
00331   x[root()]=root_value;
00332 
00333   // Perform backtracing to find optimal solution.
00334   for (int i=deps_.size()-1; i>=0; --i)
00335   {
00336     unsigned v0=deps_[i].v0;
00337     unsigned v1=deps_[i].v1;
00338     if (deps_[i].n_dep==1)
00339        x[v0]=index1_[v0][x[v1]];
00340     else
00341     {
00342       const vnl_matrix<int>& ind0 = index2_[v0];
00343       x[v0]=ind0(x[v1],x[deps_[i].v2]);
00344     }
00345   }
00346 }
00347 
00348 //=======================================================================
00349 // Method: set_from_stream
00350 //=======================================================================
00351 //: Initialise from a string stream
00352 bool mmn_dp_solver::set_from_stream(vcl_istream &is)
00353 {
00354   // Cycle through stream and produce a map of properties
00355   vcl_string s = mbl_parse_block(is);
00356   vcl_istringstream ss(s);
00357   mbl_read_props_type props = mbl_read_props_ws(ss);
00358 
00359   // No properties expected.
00360 
00361   // Check for unused props
00362   mbl_read_props_look_for_unused_props(
00363       "mmn_dp_solver::set_from_stream", props, mbl_read_props_type());
00364   return true;
00365 }
00366 
00367 
00368 //=======================================================================
00369 // Method: version_no
00370 //=======================================================================
00371 
00372 short mmn_dp_solver::version_no() const
00373 {
00374   return 1;
00375 }
00376 
00377 //=======================================================================
00378 // Method: is_a
00379 //=======================================================================
00380 
00381 vcl_string mmn_dp_solver::is_a() const
00382 {
00383   return vcl_string("mmn_dp_solver");
00384 }
00385 
00386 //: Create a copy on the heap and return base class pointer
00387 mmn_solver* mmn_dp_solver::clone() const
00388 {
00389   return new mmn_dp_solver(*this);
00390 }
00391 
00392 //=======================================================================
00393 // Method: print
00394 //=======================================================================
00395 
00396 void mmn_dp_solver::print_summary(vcl_ostream& /*os*/) const
00397 {
00398 }
00399 
00400 //=======================================================================
00401 // Method: save
00402 //=======================================================================
00403 
00404 void mmn_dp_solver::b_write(vsl_b_ostream& bfs) const
00405 {
00406   vsl_b_write(bfs,version_no());
00407   vsl_b_write(bfs,unsigned(deps_.size()));
00408   for (unsigned i=0;i<deps_.size();++i)
00409     vsl_b_write(bfs,deps_[i]);
00410 }
00411 
00412 //=======================================================================
00413 // Method: load
00414 //=======================================================================
00415 
00416 void mmn_dp_solver::b_read(vsl_b_istream& bfs)
00417 {
00418   if (!bfs) return;
00419   short version;
00420   unsigned n;
00421   vsl_b_read(bfs,version);
00422   switch (version)
00423   {
00424     case (1):
00425       vsl_b_read(bfs,n);
00426       deps_.resize(n);
00427       for (unsigned i=0;i<n;++i) vsl_b_read(bfs,deps_[i]);
00428       break;
00429     default:
00430       vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&)\n"
00431                << "           Unknown version number "<< version << vcl_endl;
00432       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00433       return;
00434   }
00435 }
00436