00001 #include "mmn_dp_solver.h"
00002
00003
00004
00005
00006
00007 #include <mmn/mmn_order_cost.h>
00008 #include <mmn/mmn_graph_rep1.h>
00009 #include <vcl_cassert.h>
00010 #include <vcl_cstdlib.h>
00011 #include <vcl_sstream.h>
00012
00013 #include <mbl/mbl_parse_block.h>
00014 #include <mbl/mbl_read_props.h>
00015
00016
00017 mmn_dp_solver::mmn_dp_solver()
00018 {
00019 }
00020
00021
00022 void mmn_dp_solver::set_arcs(unsigned num_nodes,
00023 const vcl_vector<mmn_arc>& arcs)
00024 {
00025
00026 vcl_vector<mmn_arc> ordered_arcs(arcs.size());
00027 for (unsigned i=0;i<arcs.size();++i)
00028 {
00029 if (arcs[i].v1<arcs[i].v2)
00030 ordered_arcs[i]= arcs[i];
00031 else
00032 ordered_arcs[i]= mmn_arc(arcs[i].v2,arcs[i].v1);
00033 }
00034
00035 mmn_graph_rep1 graph;
00036 graph.build(num_nodes,ordered_arcs);
00037 vcl_vector<mmn_dependancy> deps;
00038 if (!graph.compute_dependancies(deps,0))
00039 {
00040 vcl_cerr<<"Graph cannot be decomposed - too complex.\n"
00041 <<"Arc list: ";
00042 for (unsigned i=0;i<arcs.size();++i) vcl_cout<<arcs[i];
00043 vcl_cerr<<'\n';
00044 vcl_abort();
00045 }
00046
00047 set_dependancies(deps,num_nodes,graph.max_n_arcs());
00048 }
00049
00050
00051
00052 unsigned mmn_dp_solver::root() const
00053 {
00054 if (deps_.size()==0) return 0;
00055 return deps_[deps_.size()-1].v1;
00056 }
00057
00058
00059 void mmn_dp_solver::set_dependancies(const vcl_vector<mmn_dependancy>& deps,
00060 unsigned n_nodes, unsigned max_n_arcs)
00061 {
00062 deps_ = deps;
00063 nc_.resize(n_nodes);
00064 pc_.resize(max_n_arcs);
00065 index1_.resize(n_nodes);
00066 index2_.resize(n_nodes);
00067 }
00068
00069 void mmn_dp_solver::process_dep1(const mmn_dependancy& dep)
00070 {
00071
00072 const vnl_vector<double>& nc0 = nc_[dep.v0];
00073 vnl_vector<double>& nc1 = nc_[dep.v1];
00074 vnl_matrix<double>& p = pc_[dep.arc1];
00075 vcl_vector<unsigned>& i0 = index1_[dep.v0];
00076
00077
00078 if (dep.v0<dep.v1)
00079 {
00080 assert(p.rows()==nc0.size());
00081 assert(p.cols()==nc1.size());
00082 }
00083 else
00084 {
00085 if (p.rows()!=nc1.size())
00086 {
00087 vcl_cerr<<"p.rows()="<<p.rows()<<"p.cols()="<<p.cols()
00088 <<" nc0.size()="<<nc0.size()
00089 <<" nc1.size()="<<nc1.size()<<vcl_endl
00090 <<"dep: "<<dep<<vcl_endl;
00091 }
00092 assert(p.rows()==nc1.size());
00093 assert(p.cols()==nc0.size());
00094 }
00095
00096
00097 i0.resize(nc1.size());
00098 for (unsigned j=0;j<nc1.size();++j)
00099 {
00100 double min_v;
00101 unsigned best_i=0;
00102 if (dep.v0<dep.v1)
00103 {
00104 min_v=nc0[0]+p(0,j);
00105 for (unsigned i=1;i<nc0.size();++i)
00106 {
00107 double v=nc0[i]+p(i,j);
00108 if (v<min_v) { min_v=v; best_i=i; }
00109 }
00110 }
00111 else
00112 {
00113 min_v=nc0[0]+p(j,0);
00114 for (unsigned i=1;i<nc0.size();++i)
00115 {
00116 double v=nc0[i]+p(j,i);
00117 if (v<min_v) { min_v=v; best_i=i; }
00118 }
00119 }
00120 i0[j]=best_i;
00121 nc1[j]+=min_v;
00122 }
00123 }
00124
00125 void mmn_dp_solver::process_dep2(const mmn_dependancy& dep)
00126 {
00127
00128
00129
00130 const vnl_vector<double>& nc0 = nc_[dep.v0];
00131 const vnl_vector<double>& nc1 = nc_[dep.v1];
00132 const vnl_vector<double>& nc2 = nc_[dep.v2];
00133 const vnl_matrix<double>& pa1 = pc_[dep.arc1];
00134 const vnl_matrix<double>& pa2 = pc_[dep.arc2];
00135 vnl_matrix<double>& pa12 = pc_[dep.arc12];
00136 vnl_matrix<int>& ind0 = index2_[dep.v0];
00137
00138 if (pa12.size()==0)
00139 {
00140 if (dep.v1<dep.v2)
00141 pa12.set_size(nc1.size(),nc2.size());
00142 else
00143 pa12.set_size(nc2.size(),nc1.size());
00144 pa12.fill(0.0);
00145 }
00146
00147
00148 ind0.set_size(nc1.size(),nc2.size());
00149
00150 for (unsigned i1=0;i1<nc1.size();++i1)
00151 {
00152 vnl_vector<double> sum0(nc0);
00153 if (dep.v0<dep.v1) sum0+=pa1.get_column(i1);
00154 else sum0+=pa1.get_row(i1);
00155
00156 for (unsigned i2=0;i2<nc2.size();++i2)
00157 {
00158 vnl_vector<double> sum(sum0);
00159 if (dep.v0<dep.v2) sum+=pa2.get_column(i2);
00160 else sum+=pa2.get_row(i2);
00161
00162
00163
00164 unsigned best_i=0;
00165 double min_v=sum[0];
00166 for (unsigned i=1;i<sum.size();++i)
00167 if (sum[i]<min_v) { min_v=sum[i]; best_i=i; }
00168
00169
00170 ind0(i1,i2)=best_i;
00171
00172 if (dep.v1<dep.v2) { pa12(i1,i2)+=min_v; }
00173 else { pa12(i2,i1)+=min_v; }
00174 }
00175 }
00176 }
00177
00178
00179
00180
00181
00182
00183
00184 void mmn_dp_solver::process_dep2t(const mmn_dependancy& dep,
00185 const vil_image_view<double>& tri_cost)
00186 {
00187
00188
00189
00190 const vnl_vector<double>& nc0 = nc_[dep.v0];
00191 const vnl_vector<double>& nc1 = nc_[dep.v1];
00192 const vnl_vector<double>& nc2 = nc_[dep.v2];
00193 const vnl_matrix<double>& pa1 = pc_[dep.arc1];
00194 const vnl_matrix<double>& pa2 = pc_[dep.arc2];
00195 vnl_matrix<double>& pa12 = pc_[dep.arc12];
00196 vnl_matrix<int>& ind0 = index2_[dep.v0];
00197
00198
00199 vil_image_view<double> tc=mmn_unorder_cost(tri_cost,
00200 dep.v0,dep.v1,dep.v2);
00201 vcl_ptrdiff_t tc_step0=tc.istep();
00202
00203 if (pa12.size()==0)
00204 {
00205 if (dep.v1<dep.v2)
00206 pa12.set_size(nc1.size(),nc2.size());
00207 else
00208 pa12.set_size(nc2.size(),nc1.size());
00209 pa12.fill(0.0);
00210 }
00211
00212
00213 ind0.set_size(nc1.size(),nc2.size());
00214
00215 for (unsigned i1=0;i1<nc1.size();++i1)
00216 {
00217 vnl_vector<double> sum0(nc0);
00218 if (dep.v0<dep.v1) sum0+=pa1.get_column(i1);
00219 else sum0+=pa1.get_row(i1);
00220
00221 for (unsigned i2=0;i2<nc2.size();++i2)
00222 {
00223 vnl_vector<double> sum(sum0);
00224 if (dep.v0<dep.v2) sum+=pa2.get_column(i2);
00225 else sum+=pa2.get_row(i2);
00226
00227
00228
00229 unsigned best_i=0;
00230 const double *tci=&tc(0,i1,i2);
00231 double min_v=sum[0]+tci[0];
00232 tci+=tc_step0;
00233 for (unsigned i=1;i<sum.size();++i,tci+=tc_step0)
00234 {
00235 sum[i]+=(*tci);
00236 if (sum[i]<min_v) { min_v=sum[i]; best_i=i; }
00237 }
00238
00239
00240 ind0(i1,i2)=best_i;
00241
00242 if (dep.v1<dep.v2) { pa12(i1,i2)+=min_v; }
00243 else { pa12(i2,i1)+=min_v; }
00244 }
00245 }
00246 }
00247
00248
00249 double mmn_dp_solver::solve(
00250 const vcl_vector<vnl_vector<double> >& node_cost,
00251 const vcl_vector<vnl_matrix<double> >& pair_cost,
00252 vcl_vector<unsigned>& x)
00253 {
00254 nc_ = node_cost;
00255 for (unsigned i=0;i<pair_cost.size();++i) pc_[i]=pair_cost[i];
00256 for (unsigned i=pair_cost.size();i<pc_.size();++i) pc_[i].set_size(0,0);
00257
00258 if (deps_.size()==0)
00259 {
00260 vcl_cerr<<"No dependencies.\n";
00261 return 999.99;
00262 }
00263
00264
00265 vcl_vector<mmn_dependancy>::const_iterator dep=deps_.begin();
00266 for (;dep!=deps_.end();dep++)
00267 {
00268 if (dep->n_dep==1) process_dep1(*dep);
00269 else process_dep2(*dep);
00270 }
00271
00272 const vnl_vector<double>& root_cost = nc_[root()];
00273 unsigned best_i=0;
00274 double min_v=root_cost[0];
00275 for (unsigned i=1;i<root_cost.size();++i)
00276 if (root_cost[i]<min_v) { min_v=root_cost[i]; best_i=i; }
00277
00278 backtrace(best_i,x);
00279 return min_v;
00280 }
00281
00282 double mmn_dp_solver::solve(
00283 const vcl_vector<vnl_vector<double> >& node_cost,
00284 const vcl_vector<vnl_matrix<double> >& pair_cost,
00285 const vcl_vector<vil_image_view<double> >& tri_cost,
00286 vcl_vector<unsigned>& x)
00287 {
00288 nc_ = node_cost;
00289 for (unsigned i=0;i<pair_cost.size();++i) pc_[i]=pair_cost[i];
00290 for (unsigned i=pair_cost.size();i<pc_.size();++i) pc_[i].set_size(0,0);
00291
00292 if (deps_.size()==0)
00293 {
00294 vcl_cerr<<"No dependencies.\n";
00295 return 999.99;
00296 }
00297
00298
00299 vcl_vector<mmn_dependancy>::const_iterator dep=deps_.begin();
00300 for (;dep!=deps_.end();dep++)
00301 {
00302 if (dep->n_dep==1) process_dep1(*dep);
00303 else
00304 {
00305 if (dep->tri1==mmn_no_tri) process_dep2(*dep);
00306 else
00307 {
00308
00309 assert(dep->tri1 < tri_cost.size());
00310 process_dep2t(*dep,tri_cost[dep->tri1]);
00311 }
00312 }
00313 }
00314
00315 const vnl_vector<double>& root_cost = nc_[root()];
00316 unsigned best_i=0;
00317 double min_v=root_cost[0];
00318 for (unsigned i=1;i<root_cost.size();++i)
00319 if (root_cost[i]<min_v) { min_v=root_cost[i]; best_i=i; }
00320
00321 backtrace(best_i,x);
00322 return min_v;
00323 }
00324
00325
00326
00327
00328 void mmn_dp_solver::backtrace(unsigned root_value,vcl_vector<unsigned>& x)
00329 {
00330 x.resize(nc_.size());
00331 x[root()]=root_value;
00332
00333
00334 for (int i=deps_.size()-1; i>=0; --i)
00335 {
00336 unsigned v0=deps_[i].v0;
00337 unsigned v1=deps_[i].v1;
00338 if (deps_[i].n_dep==1)
00339 x[v0]=index1_[v0][x[v1]];
00340 else
00341 {
00342 const vnl_matrix<int>& ind0 = index2_[v0];
00343 x[v0]=ind0(x[v1],x[deps_[i].v2]);
00344 }
00345 }
00346 }
00347
00348
00349
00350
00351
00352 bool mmn_dp_solver::set_from_stream(vcl_istream &is)
00353 {
00354
00355 vcl_string s = mbl_parse_block(is);
00356 vcl_istringstream ss(s);
00357 mbl_read_props_type props = mbl_read_props_ws(ss);
00358
00359
00360
00361
00362 mbl_read_props_look_for_unused_props(
00363 "mmn_dp_solver::set_from_stream", props, mbl_read_props_type());
00364 return true;
00365 }
00366
00367
00368
00369
00370
00371
00372 short mmn_dp_solver::version_no() const
00373 {
00374 return 1;
00375 }
00376
00377
00378
00379
00380
00381 vcl_string mmn_dp_solver::is_a() const
00382 {
00383 return vcl_string("mmn_dp_solver");
00384 }
00385
00386
00387 mmn_solver* mmn_dp_solver::clone() const
00388 {
00389 return new mmn_dp_solver(*this);
00390 }
00391
00392
00393
00394
00395
00396 void mmn_dp_solver::print_summary(vcl_ostream& ) const
00397 {
00398 }
00399
00400
00401
00402
00403
00404 void mmn_dp_solver::b_write(vsl_b_ostream& bfs) const
00405 {
00406 vsl_b_write(bfs,version_no());
00407 vsl_b_write(bfs,unsigned(deps_.size()));
00408 for (unsigned i=0;i<deps_.size();++i)
00409 vsl_b_write(bfs,deps_[i]);
00410 }
00411
00412
00413
00414
00415
00416 void mmn_dp_solver::b_read(vsl_b_istream& bfs)
00417 {
00418 if (!bfs) return;
00419 short version;
00420 unsigned n;
00421 vsl_b_read(bfs,version);
00422 switch (version)
00423 {
00424 case (1):
00425 vsl_b_read(bfs,n);
00426 deps_.resize(n);
00427 for (unsigned i=0;i<n;++i) vsl_b_read(bfs,deps_[i]);
00428 break;
00429 default:
00430 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&)\n"
00431 << " Unknown version number "<< version << vcl_endl;
00432 bfs.is().clear(vcl_ios::badbit);
00433 return;
00434 }
00435 }
00436