00001 #include "mmn_diffusion_solver.h"
00002
00003
00004
00005
00006
00007
00008
00009
00010 #include <mmn/mmn_csp_solver.h>
00011 #include <vcl_algorithm.h>
00012 #include <vcl_iterator.h>
00013 #include <vcl_sstream.h>
00014 #include <vcl_cmath.h>
00015 #include <vnl/vnl_vector_ref.h>
00016 #include <mbl/mbl_exception.h>
00017 #include <mbl/mbl_stl.h>
00018 #include <mbl/mbl_stl_pred.h>
00019
00020
00021 unsigned mmn_diffusion_solver::gNCONVERGED=3;
00022 unsigned mmn_diffusion_solver::gACS_CHECK_PERIOD=10;
00023
00024
00025 mmn_diffusion_solver::mmn_diffusion_solver()
00026 : nnodes_(0), max_iterations_(2000), min_iterations_(200), epsilon_(1.0E-5),
00027 verbose_(false)
00028 {
00029 init();
00030 }
00031
00032
00033 mmn_diffusion_solver::mmn_diffusion_solver(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs)
00034 : max_iterations_(2000), min_iterations_(200), epsilon_(1.0E-5), verbose_(false)
00035 {
00036 init();
00037 set_arcs(num_nodes,arcs);
00038 }
00039
00040 void mmn_diffusion_solver::init()
00041 {
00042 count_=0;
00043 max_delta_=-1.0;
00044 nConverging_=0;
00045 soln_val_prev_=-1.0E30;
00046 }
00047
00048
00049 void mmn_diffusion_solver::set_arcs(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs)
00050 {
00051 nnodes_=num_nodes;
00052 arcs_ = arcs;
00053
00054 unsigned max_node=0;
00055 for (unsigned i=0; i<arcs.size();++i)
00056 {
00057 max_node=vcl_max(max_node,arcs[i].max_v());
00058 }
00059 if (nnodes_ != max_node+1)
00060 {
00061 vcl_cerr<<"Arcs appear to be inconsistent with number of nodes in mmn_diffusion_solver::set_arcs\n"
00062 <<"Max node in Arcs is: "<<max_node<<" but number of nodes= "<<nnodes_ << '\n';
00063 }
00064
00065 graph_.build(nnodes_,arcs_);
00066
00067 arc_costs_.clear();
00068 arc_costs_phi_.clear();
00069 node_costs_.clear();
00070 node_costs_phi_.clear();
00071 phi_.clear();
00072 phi_upd_.clear();
00073 u_.clear();
00074
00075 phi_.resize(nnodes_);
00076 phi_upd_=phi_;
00077 arc_costs_.resize(nnodes_);
00078 arc_costs_phi_.resize(nnodes_);
00079 u_.resize(nnodes_);
00080
00081
00082 max_iterations_ = min_iterations_ + 2*(nnodes_ * arcs_.size());
00083 }
00084
00085
00086 vcl_pair<bool,double> mmn_diffusion_solver::operator()(const vcl_vector<vnl_vector<double> >& node_costs,
00087 const vcl_vector<vnl_matrix<double> >& pair_costs,
00088 vcl_vector<unsigned>& x)
00089 {
00090 init();
00091
00092 x.resize(nnodes_);
00093 vcl_fill(x.begin(),x.end(),0);
00094
00095 node_costs_.resize(nnodes_);
00096 for (unsigned i=0;i<nnodes_;++i)
00097 {
00098
00099 node_costs_[i] = - node_costs[i];
00100 }
00101 node_costs_phi_ = node_costs_;
00102
00103 const vcl_vector<vcl_vector<vcl_pair<unsigned,unsigned> > >& neighbourhoods=graph_.node_data();
00104 for (unsigned inode=0; inode<neighbourhoods.size();++inode)
00105 {
00106 const vcl_vector<vcl_pair<unsigned,unsigned> >& neighbours=neighbourhoods[inode];
00107 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIter=neighbours.begin();
00108 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIterEnd=neighbours.end();
00109 while (neighIter != neighIterEnd)
00110 {
00111 unsigned arcId=neighIter->second;
00112 vnl_matrix<double>& linkCosts = arc_costs_[inode][neighIter->first];
00113 const vnl_matrix<double >& srcArcCosts=pair_costs[arcId];
00114 mmn_arc& arc=arcs_[arcId];
00115 unsigned v1=arc.v1;
00116 unsigned v2=arc.v2;
00117 unsigned minv=arc.min_v();
00118 if (inode!=v1 && inode!=v2)
00119 {
00120 vcl_string msg("Graph inconsistency in mmn_diffusion_solver::operator()\n");
00121 vcl_ostringstream os;
00122 os <<"Source node is "<<inode<<" but arc to alleged neighbour joins nodes "<<v1<<"\t to "<<v2<<'\n';
00123 msg+= os.str();
00124 vcl_cerr<<msg<<vcl_endl;
00125 throw mbl_exception_abort(msg);
00126 }
00127
00128 if (inode==minv)
00129 {
00130 linkCosts=srcArcCosts;
00131 }
00132 else
00133 {
00134 linkCosts=srcArcCosts.transpose();
00135 }
00136 linkCosts*= -1.0;
00137 unsigned nstates=linkCosts.rows();
00138 phi_[inode][neighIter->first].set_size(nstates);
00139 phi_[inode][neighIter->first].fill(0.0);
00140
00141 u_[inode][neighIter->first].set_size(nstates);
00142 u_[inode][neighIter->first].fill(0.0);
00143
00144 ++neighIter;
00145 }
00146 }
00147
00148 arc_costs_phi_ = arc_costs_;
00149 node_costs_phi_ = node_costs_;
00150 phi_upd_ = phi_;
00151
00152
00153 vcl_vector<unsigned > random_indices(nnodes_,0);
00154 mbl_stl_increments(random_indices.begin(),random_indices.end(),0);
00155 do
00156 {
00157 max_delta_=-1.0;
00158 vcl_random_shuffle(random_indices.begin(),random_indices.end());
00159
00160 for (unsigned knode=0; knode<nnodes_;++knode)
00161 {
00162 unsigned inode=random_indices[knode];
00163
00164 update_potentials_to_neighbours(inode,node_costs_phi_[inode]);
00165
00166 phi_[inode]=phi_upd_[inode];
00167
00168 transform_costs(inode);
00169 }
00170
00171 if (verbose_)
00172 {
00173 vcl_cout<<"Max potential delta at iteration "<<count_<<"\t is "<<max_delta_<<vcl_endl;
00174 }
00175 }
00176 while (continue_diffusion());
00177
00178
00179 bool ok = arc_consistent_solution(x);
00180
00181 return vcl_pair<bool,double>(ok,-solution_cost(x));
00182 }
00183
00184 void mmn_diffusion_solver::transform_costs()
00185 {
00186 for (unsigned inode=0; inode<nnodes_;++inode)
00187 {
00188 transform_costs(inode);
00189 }
00190 }
00191
00192 void mmn_diffusion_solver::transform_costs(unsigned inode)
00193 {
00194
00195
00196 unsigned nStates=node_costs_[inode].size();
00197 for (unsigned xlabel=0; xlabel<nStates;++xlabel)
00198 {
00199 const vcl_vector<vcl_pair<unsigned,unsigned> >& neighbours=graph_.node_data()[inode];
00200
00201 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIter=neighbours.begin();
00202 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIterEnd=neighbours.end();
00203 double phiTot=0.0;
00204 while (neighIter != neighIterEnd)
00205 {
00206 double& phix=phi_[inode][neighIter->first][xlabel];
00207 phiTot+= phix;
00208 vnl_matrix<double>& linkCosts = arc_costs_[inode][neighIter->first];
00209 vnl_matrix<double>& linkCostsPhi = arc_costs_phi_[inode][neighIter->first];
00210 unsigned nNeighStates=linkCosts.cols();
00211 const vnl_vector<double >& phiTransposed=phi_[neighIter->first][inode];
00212 for (unsigned xprime=0; xprime<nNeighStates;++xprime)
00213 {
00214
00215 linkCostsPhi(xlabel,xprime) = linkCosts(xlabel,xprime) -(phix + phiTransposed[xprime]);
00216 }
00217 ++neighIter;
00218 }
00219
00220 node_costs_phi_[inode][xlabel] = node_costs_[inode][xlabel] + phiTot;
00221 }
00222 }
00223
00224 double mmn_diffusion_solver::solution_cost(vcl_vector<unsigned>& x)
00225 {
00226
00227 double sumNodes=0.0;
00228
00229 vcl_vector<vnl_vector<double> >::const_iterator nodeIter=node_costs_.begin();
00230 vcl_vector<vnl_vector<double> >::const_iterator nodeIterEnd=node_costs_.end();
00231 vcl_vector<unsigned >::const_iterator stateIter=x.begin();
00232 while (nodeIter != nodeIterEnd)
00233 {
00234 const vnl_vector<double>& ncosts = *nodeIter;
00235 sumNodes+=ncosts[*stateIter];
00236 ++nodeIter;++stateIter;
00237 }
00238
00239
00240 vcl_vector<mmn_arc>::const_iterator arcIter=arcs_.begin();
00241 vcl_vector<mmn_arc>::const_iterator arcIterEnd=arcs_.end();
00242 double sumArcs=0.0;
00243 while (arcIter != arcIterEnd)
00244 {
00245 unsigned nodeId1=arcIter->v1;
00246 unsigned nodeId2=arcIter->v2;
00247 sumArcs += arc_costs_[nodeId1][nodeId2](x[nodeId1],x[nodeId2]);
00248 ++arcIter;
00249 }
00250 return sumNodes+sumArcs;
00251 }
00252
00253
00254 void mmn_diffusion_solver::update_potentials_to_neighbours(unsigned inode,
00255 const vnl_vector<double>& node_cost)
00256 {
00257
00258 unsigned nStates=node_cost.size();
00259 const vcl_vector<vcl_pair<unsigned,unsigned> >& neighbours=graph_.node_data()[inode];
00260 for (unsigned xlabel=0; xlabel<nStates;++xlabel)
00261 {
00262 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIter=neighbours.begin();
00263 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIterEnd=neighbours.end();
00264 double du=node_cost[xlabel];
00265 while (neighIter != neighIterEnd)
00266 {
00267
00268 vnl_vector<double>& uToNeigh = u_[inode][neighIter->first];
00269 vnl_matrix<double>& linkCosts = arc_costs_phi_[inode][neighIter->first];
00270 double* pgRow=linkCosts[xlabel];
00271 uToNeigh[xlabel] = *(vcl_max_element(pgRow,pgRow+linkCosts.cols()));
00272
00273 du += uToNeigh[xlabel];
00274 ++neighIter;
00275 }
00276
00277 du /= (double(1.0+neighbours.size()));
00278
00279
00280 neighIter=neighbours.begin();
00281 while (neighIter != neighIterEnd)
00282 {
00283 vnl_vector<double>& uToNeigh = u_[inode][neighIter->first];
00284 double delta = (uToNeigh[xlabel] - du);
00285 phi_upd_[inode][neighIter->first][xlabel] += delta;
00286 max_delta_ = vcl_max(max_delta_,delta);
00287 ++neighIter;
00288 }
00289 }
00290 }
00291
//: Try to read a labelling off the current transformed costs.
//  Each node's candidate labels are those whose node_costs_phi_ value is near
//  (within epsilon_cost of) that node's maximum.  Each arc's candidate label
//  pairs (x,x') are those whose arc_costs_phi_ entry is near the arc's overall
//  maximum, searched only in rows whose maximum is itself near that value.
//  mmn_csp_solver is then run on these candidate sets, and its verdict is
//  returned.  If it is true, x[i] becomes the first label in node i's kernel
//  label set (presumably the labels surviving the solver's pruning).
//  Otherwise x[i] becomes node i's first candidate label.
//  Side effect: overwrites u_[min_v][max_v] with the row maxima for every arc.
bool mmn_diffusion_solver::arc_consistent_solution(vcl_vector<unsigned>& x)
{
  // Candidate labels per node, and candidate label pairs per arc.
  vcl_vector<mmn_csp_solver::label_subset_t > node_labels_subset(nnodes_);
  vcl_vector<mmn_csp_solver::arc_labels_subset_t > links_subset(arcs_.size());

  // Tolerance deciding whether a value counts as "at the maximum".
  const double epsilon_cost = 1.0E-6;
  for (unsigned inode=0; inode<nnodes_;++inode)
  {
    vnl_vector<double> labelCosts=node_costs_phi_[inode];

    // NOTE(review): assumes every node has at least one label; for an empty
    // vector vcl_max_element returns end(), which would be dereferenced here.
    double lmax=*vcl_max_element(labelCosts.begin(),labelCosts.end());

    // index = 0,1,...,labelCosts.size()-1
    vcl_vector<unsigned > index(labelCosts.size(),0) ;
    mbl_stl_increments(index.begin(),index.end(),0);

    // Keep the label indices whose cost is near lmax.  Presumably the adapter
    // applies the predicate to labelCosts[index] - confirm in mbl_stl_pred.h.
    mbl_stl_copy_if(index.begin(),index.end(),
                    vcl_inserter(node_labels_subset[inode],node_labels_subset[inode].end()),
                    mbl_stl_pred_create_index_adapter(labelCosts,
                                                      mbl_stl_pred_is_near(lmax,epsilon_cost)));
  }
  for (unsigned arcId=0;arcId<arcs_.size();++arcId)
  {
    // Running maximum over the whole arc cost matrix.
    double umax=-1.0E30;
    // operator() orients arc costs so that rows of [min_v][max_v] index the
    // labels of the lower-numbered node.
    unsigned srcId=arcs_[arcId].min_v();
    unsigned targetId=arcs_[arcId].max_v();

    vnl_vector<double>& uToNeigh = u_[srcId][targetId];
    vnl_matrix<double>& linkCosts = arc_costs_phi_[srcId][targetId];

    // Record each row's maximum in u_ and track the overall maximum.
    unsigned nStates=linkCosts.rows();
    for (unsigned xlabel=0; xlabel<nStates;++xlabel)
    {
      double* pgRow=linkCosts[xlabel];
      double u = *(vcl_max_element(pgRow,pgRow+linkCosts.cols()));

      umax = vcl_max(u,umax);
      uToNeigh[xlabel] = u;
    }

    // Rows whose maximum is near the overall maximum.
    vcl_vector<unsigned > xindex(linkCosts.rows(),0) ;
    mbl_stl_increments(xindex.begin(),xindex.end(),0);
    vcl_vector<unsigned> maxRows;

    mbl_stl_copy_if(xindex.begin(),xindex.end(),
                    vcl_back_inserter(maxRows),
                    mbl_stl_pred_create_index_adapter(uToNeigh,
                                                      mbl_stl_pred_is_near(umax,epsilon_cost)));
    vcl_vector<unsigned>::iterator rowIter=maxRows.begin();
    vcl_vector<unsigned>::iterator rowIterEnd=maxRows.end();
    while (rowIter != rowIterEnd)
    {
      // Within each such row, every column whose entry is near umax gives a
      // candidate (row label, column label) pair for this arc.
      unsigned xlabel=*rowIter;
      vnl_vector_ref<double > row(linkCosts.cols(),linkCosts[xlabel]);
      mbl_stl_pred_is_near nearMax(umax,epsilon_cost);
      for (unsigned xprime=0;xprime<linkCosts.cols();++xprime)
      {
        if (nearMax(row[xprime]))
        {
          links_subset[arcId].insert(vcl_pair<unsigned ,unsigned >(xlabel,xprime));
        }
      }

      ++rowIter;
    }
  }

  // Ask the constraint-satisfaction solver whether the candidates are arc
  // consistent.
  mmn_csp_solver cspSolver(nnodes_,arcs_);
  bool arcConsistent=cspSolver(node_labels_subset,links_subset);
  if (arcConsistent)
  {
    const vcl_vector<mmn_csp_solver::label_subset_t >& kernel_node_labels=cspSolver.kernel_node_labels();
    for (unsigned inode=0; inode<nnodes_;++inode)
    {
      x[inode]=*(kernel_node_labels[inode].begin());
    }
  }
  else
  {
    // Not arc consistent: fall back to the first near-maximal label per node.
    // NOTE(review): dereferences begin() of the candidate set, which would be
    // empty if no label passed the is_near test (e.g. with NaN costs).
    for (unsigned inode=0; inode<nnodes_;++inode)
    {
      x[inode] = *(node_labels_subset[inode].begin());
    }
  }
  return arcConsistent;
}
00388
00389 bool mmn_diffusion_solver::continue_diffusion()
00390 {
00391 ++count_;
00392 bool retstate=true;
00393 if (max_delta_<epsilon_ || count_>max_iterations_)
00394 {
00395
00396 retstate = false;
00397 }
00398 else if (count_>min_iterations_)
00399 {
00400
00401
00402 if (count_ % gACS_CHECK_PERIOD==0)
00403 {
00404 vcl_vector<unsigned> x(nnodes_,0);
00405 bool ok = arc_consistent_solution(x);
00406 if (ok)
00407 {
00408 double soln_val=solution_cost(x);
00409 if (verbose_)
00410 {
00411 vcl_cout<<"Arc consistent solution reached. "
00412 <<"\tSolution value= "<<soln_val<<"\tprev soln val= "<<soln_val_prev_<<vcl_endl;
00413 }
00414 if (vcl_fabs(soln_val-soln_val_prev_)<epsilon_)
00415 {
00416 ++nConverging_;
00417 if (nConverging_>gNCONVERGED)
00418 {
00419 retstate = false;
00420 }
00421 }
00422 else
00423 {
00424 nConverging_=0;
00425 }
00426
00427 soln_val_prev_=soln_val;
00428 }
00429 else
00430 {
00431 if (verbose_)
00432 {
00433 vcl_cout<<"Solution is not yet arc consistent."<<vcl_endl;
00434 }
00435 nConverging_=0;
00436 }
00437 }
00438 }
00439 return retstate;
00440 }
00441