Go to the documentation of this file.
00001 #include "mmn_diffusion_solver.h"
00002 //:
00003 // \file
00004 // \brief Run diffusion algorithm to solve max sum problem
00005 // \author Martin Roberts
00006 //
// See T. Werner, "A Linear Programming Approach to Max-Sum Problem: A Review",
// IEEE Trans. on Pattern Analysis and Machine Intelligence, 29(7), July 2007
00010 #include <mmn/mmn_csp_solver.h>
00011 #include <vcl_algorithm.h>
00012 #include <vcl_iterator.h>
00013 #include <vcl_sstream.h>
00014 #include <vcl_cmath.h>
00015 #include <vnl/vnl_vector_ref.h>
00016 #include <mbl/mbl_exception.h>
00017 #include <mbl/mbl_stl.h>
00018 #include <mbl/mbl_stl_pred.h>
00020 //Static magic numbers
00021 unsigned mmn_diffusion_solver::gNCONVERGED=3;
00022 unsigned mmn_diffusion_solver::gACS_CHECK_PERIOD=10;
00024 //: Default constructor
00025 mmn_diffusion_solver::mmn_diffusion_solver()
00026   : nnodes_(0), max_iterations_(2000), min_iterations_(200), epsilon_(1.0E-5),
00027     verbose_(false)
00028 {
00029     init();
00030 }
00032 //: Construct with arcs
00033 mmn_diffusion_solver::mmn_diffusion_solver(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs)
00034 : max_iterations_(2000), min_iterations_(200), epsilon_(1.0E-5), verbose_(false)
00035 {
00036     init();
00037     set_arcs(num_nodes,arcs);
00038 }
00040 void mmn_diffusion_solver::init()
00041 {
00042     count_=0;
00043     max_delta_=-1.0;
00044     nConverging_=0;
00045     soln_val_prev_=-1.0E30;
00046 }
00048 //: Pass in the arcs, which are then used to build the graph object
00049 void mmn_diffusion_solver::set_arcs(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs)
00050 {
00051     nnodes_=num_nodes;
00052     arcs_ = arcs;
00053     //Verify consistency
00054     unsigned max_node=0;
00055     for (unsigned i=0; i<arcs.size();++i)
00056     {
00057         max_node=vcl_max(max_node,arcs[i].max_v());
00058     }
00059     if (nnodes_ != max_node+1)
00060     {
00061         vcl_cerr<<"Arcs appear to be inconsistent with number of nodes in mmn_diffusion_solver::set_arcs\n"
00062                 <<"Max node in Arcs is: "<<max_node<<" but number of nodes= "<<nnodes_ << '\n';
00063     }
00067     arc_costs_.clear();
00068     arc_costs_phi_.clear();
00069     node_costs_.clear();
00070     node_costs_phi_.clear();
00071     phi_.clear();
00072     phi_upd_.clear();
00073     u_.clear();
00075     phi_.resize(nnodes_);
00076     phi_upd_=phi_;
00077     arc_costs_.resize(nnodes_);
00078     arc_costs_phi_.resize(nnodes_);
00079     u_.resize(nnodes_);
00081     //Set max iterations, somewhat arbitrarily, increasing with nodes and arcs
00082     max_iterations_ = min_iterations_ + 2*(nnodes_ * arcs_.size());
00083 }
00085 //: Run the algorithm
00086 vcl_pair<bool,double> mmn_diffusion_solver::operator()(const vcl_vector<vnl_vector<double> >& node_costs,
00087                                   const vcl_vector<vnl_matrix<double> >& pair_costs,
00088                                   vcl_vector<unsigned>& x)
00089 {
00090     init();
00092     x.resize(nnodes_);
00093     vcl_fill(x.begin(),x.end(),0);
00095     node_costs_.resize(nnodes_);
00096     for (unsigned i=0;i<nnodes_;++i)
00097     {
00098         //: Negate costs to convert to log prob (i.e. internally we do maximisation)
00099         node_costs_[i] = - node_costs[i];
00100     }
00101     node_costs_phi_ = node_costs_;
00102     //Initialise potential structure and neighbourhood cost representation
00103     const vcl_vector<vcl_vector<vcl_pair<unsigned,unsigned> > >& neighbourhoods=graph_.node_data();
00104     for (unsigned inode=0; inode<neighbourhoods.size();++inode)
00105     {
00106         const vcl_vector<vcl_pair<unsigned,unsigned> >& neighbours=neighbourhoods[inode];
00107         vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIter=neighbours.begin();
00108         vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIterEnd=neighbours.end();
00109         while (neighIter != neighIterEnd) //do all neighbours of this node
00110         {
00111             unsigned arcId=neighIter->second;
00112             vnl_matrix<double>& linkCosts = arc_costs_[inode][neighIter->first];
00113             const vnl_matrix<double >& srcArcCosts=pair_costs[arcId];
00114             mmn_arc& arc=arcs_[arcId];
00115             unsigned v1=arc.v1;
00116             unsigned v2=arc.v2;
00117             unsigned minv=arc.min_v();
00118             if (inode!=v1 && inode!=v2)
00119             {
00120                 vcl_string msg("Graph inconsistency in mmn_diffusion_solver::operator()\n");
00121                 vcl_ostringstream os;
00122                 os <<"Source node is "<<inode<<" but arc to alleged neighbour joins nodes "<<v1<<"\t to "<<v2<<'\n';
00123                 msg+= os.str();
00124                 vcl_cerr<<msg<<vcl_endl;
00125                 throw mbl_exception_abort(msg);
00126             }
00128             if (inode==minv)
00129             {
00130                 linkCosts=srcArcCosts;
00131             }
00132             else //transpose arc costs as this node is the 2nd element in the pair_costs matrix
00133             {
00134                 linkCosts=srcArcCosts.transpose();
00135             }
00136             linkCosts*= -1.0; //convert to maximising log prob (not min -log prob)
00137             unsigned nstates=linkCosts.rows(); //number of source states
00138             phi_[inode][neighIter->first].set_size(nstates);
00139             phi_[inode][neighIter->first].fill(0.0); //set all initial potentials to zero
00141             u_[inode][neighIter->first].set_size(nstates);
00142             u_[inode][neighIter->first].fill(0.0); //set all initial potential changes to zero
00143                                                    //
00144             ++neighIter; //next neighbour of this node
00145         }
00146     } //next node
00148     arc_costs_phi_ = arc_costs_;
00149     node_costs_phi_ = node_costs_;
00150     phi_upd_ = phi_;
00152     //Now keep repeating node-pencil averaging
00153     vcl_vector<unsigned > random_indices(nnodes_,0);
00154     mbl_stl_increments(random_indices.begin(),random_indices.end(),0);
00155     do
00156     {
00157         max_delta_=-1.0;
00158         vcl_random_shuffle(random_indices.begin(),random_indices.end());
00159         //Randomise the order of pencils
00160         for (unsigned knode=0; knode<nnodes_;++knode)
00161         {
00162             unsigned inode=random_indices[knode];
00163             //Do all pencils from this node
00164             update_potentials_to_neighbours(inode,node_costs_phi_[inode]);
00166             phi_[inode]=phi_upd_[inode];
00167             //Update costs for node and its pencils given phi
00168             transform_costs(inode);
00169         }
00171         if (verbose_)
00172         {
00173             vcl_cout<<"Max potential delta at iteration "<<count_<<"\t is "<<max_delta_<<vcl_endl;
00174         }
00175     }
00176     while (continue_diffusion());
00178     //Now check final "trivial" solution is arc consistent
00179     bool ok = arc_consistent_solution(x);
00181     return vcl_pair<bool,double>(ok,-solution_cost(x));
00182 }
00184 void mmn_diffusion_solver::transform_costs()
00185 {
00186     for (unsigned inode=0; inode<nnodes_;++inode)
00187     {
00188         transform_costs(inode);
00189     }
00190 }
00192 void mmn_diffusion_solver::transform_costs(unsigned inode)
00193 {
00194     //Add on potentials to nodes and subtract from arcs to transform to equivalent problem
00195     // with (hopefully) a tighter upper bound
00196     unsigned nStates=node_costs_[inode].size();
00197     for (unsigned xlabel=0; xlabel<nStates;++xlabel) //Loop over labels of node
00198     {
00199         const vcl_vector<vcl_pair<unsigned,unsigned> >& neighbours=graph_.node_data()[inode];
00201         vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIter=neighbours.begin();
00202         vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIterEnd=neighbours.end();
00203         double phiTot=0.0; //total added to node cost
00204         while (neighIter != neighIterEnd) //Loop over all my neighbours
00205         {
00206             double& phix=phi_[inode][neighIter->first][xlabel];
00207             phiTot+= phix; //add contribution for this arc
00208             vnl_matrix<double>& linkCosts = arc_costs_[inode][neighIter->first];
00209             vnl_matrix<double>& linkCostsPhi = arc_costs_phi_[inode][neighIter->first];
00210             unsigned nNeighStates=linkCosts.cols();
00211             const vnl_vector<double >& phiTransposed=phi_[neighIter->first][inode];
00212             for (unsigned xprime=0; xprime<nNeighStates;++xprime) //Loop over labels of neighbour
00213             {
00214                 // Update equivalent link costs given phi
00215                 linkCostsPhi(xlabel,xprime) = linkCosts(xlabel,xprime) -(phix + phiTransposed[xprime]);
00216             }
00217             ++neighIter;
00218         }
00219         //Update equivalent node costs given phi
00220         node_costs_phi_[inode][xlabel] = node_costs_[inode][xlabel] + phiTot;
00221     } //labels of this node
00222 }
00224 double mmn_diffusion_solver::solution_cost(vcl_vector<unsigned>& x)
00225 {
00226     //: Calculate objective function for solution x
00227     double sumNodes=0.0;
00228     //Sum over all nodes
00229     vcl_vector<vnl_vector<double> >::const_iterator nodeIter=node_costs_.begin();
00230     vcl_vector<vnl_vector<double> >::const_iterator nodeIterEnd=node_costs_.end();
00231     vcl_vector<unsigned >::const_iterator stateIter=x.begin();
00232     while (nodeIter != nodeIterEnd)
00233     {
00234         const vnl_vector<double>& ncosts = *nodeIter;
00235         sumNodes+=ncosts[*stateIter];
00236         ++nodeIter;++stateIter;
00237     }
00239     // Sum over all arcs
00240     vcl_vector<mmn_arc>::const_iterator arcIter=arcs_.begin();
00241     vcl_vector<mmn_arc>::const_iterator arcIterEnd=arcs_.end();
00242     double sumArcs=0.0;
00243     while (arcIter != arcIterEnd)
00244     {
00245         unsigned nodeId1=arcIter->v1;
00246         unsigned nodeId2=arcIter->v2;
00247         sumArcs += arc_costs_[nodeId1][nodeId2](x[nodeId1],x[nodeId2]);
00248         ++arcIter;
00249     }
00250     return sumNodes+sumArcs;
00251 }
00254 void mmn_diffusion_solver::update_potentials_to_neighbours(unsigned inode,
00255                                                            const vnl_vector<double>& node_cost)
00256 {
00257     //Update all potentials from this node to its neighbours
00258     unsigned nStates=node_cost.size();
00259     const vcl_vector<vcl_pair<unsigned,unsigned> >& neighbours=graph_.node_data()[inode];
00260     for (unsigned xlabel=0; xlabel<nStates;++xlabel) //loop over my labels (i.e. each pencil)
00261     {
00262         vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIter=neighbours.begin();
00263         vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIterEnd=neighbours.end();
00264         double du=node_cost[xlabel];
00265         while (neighIter != neighIterEnd) //Loop over all my neighbours
00266         {
00267             //Compute contribution of this neighbour to the node-pencil averaging
00268             vnl_vector<double>& uToNeigh = u_[inode][neighIter->first];
00269             vnl_matrix<double>& linkCosts = arc_costs_phi_[inode][neighIter->first];
00270             double* pgRow=linkCosts[xlabel];
00271             uToNeigh[xlabel] = *(vcl_max_element(pgRow,pgRow+linkCosts.cols())); //max arc cost of pencil
00273             du += uToNeigh[xlabel];
00274             ++neighIter;
00275         }
00277         du /= (double(1.0+neighbours.size())); //average
00279         //Now update potentials given du
00280         neighIter=neighbours.begin();
00281         while (neighIter != neighIterEnd) //Loop over all my neighbours
00282         {
00283             vnl_vector<double>& uToNeigh = u_[inode][neighIter->first];
00284             double delta = (uToNeigh[xlabel] - du);
00285             phi_upd_[inode][neighIter->first][xlabel] += delta;
00286             max_delta_ = vcl_max(max_delta_,delta);
00287             ++neighIter;
00288         }
00289     }
00290 }
00292 bool mmn_diffusion_solver::arc_consistent_solution(vcl_vector<unsigned>& x)
00293 {
00294     // Find for each node the maximum label(s), and the maximal connecting arcs
00295     // Check if this set form an arc consistent solution
00296     // If so set x to kernel
00297     // Otherwise x is set to the first maximal node label
00299     vcl_vector<mmn_csp_solver::label_subset_t > node_labels_subset(nnodes_);
00300     vcl_vector<mmn_csp_solver::arc_labels_subset_t > links_subset(arcs_.size());
00302     const double epsilon_cost = 1.0E-6;
00303     for (unsigned inode=0; inode<nnodes_;++inode) //Loop over nodes
00304     {
00305         vnl_vector<double> labelCosts=node_costs_phi_[inode];
00306         //: Find (possibly non-unique) maximal label value
00307         double lmax=*vcl_max_element(labelCosts.begin(),labelCosts.end());
00308         //: Then compile vector of all node indices with label value "near" this
00309         vcl_vector<unsigned  > index(labelCosts.size(),0) ;
00310         mbl_stl_increments(index.begin(),index.end(),0);
00311         //Insert all indices of elements = (or very close to) max value
00312         mbl_stl_copy_if(index.begin(),index.end(),
00313                         vcl_inserter(node_labels_subset[inode],node_labels_subset[inode].end()),
00314                         mbl_stl_pred_create_index_adapter(labelCosts,
00315                                                           mbl_stl_pred_is_near(lmax,epsilon_cost)));
00316     }
00317     for (unsigned arcId=0;arcId<arcs_.size();++arcId)
00318     {
00319         //Now find maximal edges by looping over all arcs
00320         double umax=-1.0E30;
00321         unsigned srcId=arcs_[arcId].min_v();
00322         unsigned targetId=arcs_[arcId].max_v();
00324         vnl_vector<double>& uToNeigh = u_[srcId][targetId];
00325         vnl_matrix<double>& linkCosts = arc_costs_phi_[srcId][targetId];
00327         unsigned nStates=linkCosts.rows();
00328         for (unsigned xlabel=0; xlabel<nStates;++xlabel) //Loop over labels of source node
00329         {
00330             double* pgRow=linkCosts[xlabel];
00331             double u = *(vcl_max_element(pgRow,pgRow+linkCosts.cols())); //max arc cost of pencil
00332             //And now look for max cost pencil
00333             umax = vcl_max(u,umax);
00334             uToNeigh[xlabel] = u;
00335         }
00338         vcl_vector<unsigned  > xindex(linkCosts.rows(),0) ;
00339         mbl_stl_increments(xindex.begin(),xindex.end(),0);
00340         vcl_vector<unsigned> maxRows;
00341         //Insert all indices of elements (source node labels) = (or very close to) max value
00342         mbl_stl_copy_if(xindex.begin(),xindex.end(),
00343                         vcl_back_inserter(maxRows),
00344                         mbl_stl_pred_create_index_adapter(uToNeigh,
00345                                                           mbl_stl_pred_is_near(umax,epsilon_cost)));
00346         vcl_vector<unsigned>::iterator rowIter=maxRows.begin();
00347         vcl_vector<unsigned>::iterator rowIterEnd=maxRows.end();
00348         while (rowIter != rowIterEnd)
00349         {
00350             //And for each such pencil locate the index of the maximising label to which it connects
00351             unsigned xlabel=*rowIter;
00352             vnl_vector_ref<double > row(linkCosts.cols(),linkCosts[xlabel]);
00353             mbl_stl_pred_is_near nearMax(umax,epsilon_cost);
00354             for (unsigned xprime=0;xprime<linkCosts.cols();++xprime)
00355             {
00356                 if (nearMax(row[xprime]))
00357                 {
00358                     links_subset[arcId].insert(vcl_pair<unsigned ,unsigned >(xlabel,xprime));
00359                 }
00360             }
00362             ++rowIter;
00363         }
00364     }//Arcs
00366     //CSP checker
00367     mmn_csp_solver cspSolver(nnodes_,arcs_);
00368     bool arcConsistent=cspSolver(node_labels_subset,links_subset);
00369     if (arcConsistent)
00370     {
00371         const vcl_vector<mmn_csp_solver::label_subset_t >& kernel_node_labels=cspSolver.kernel_node_labels();
00372         for (unsigned inode=0; inode<nnodes_;++inode)
00373         {
00374             x[inode]=*(kernel_node_labels[inode].begin());
00375         }
00376     }
00377     else
00378     {
00379         //Transformed problem is not arc consistent but fill up the solution vector anyway
00380         //Set to first maximal node label in each case
00381         for (unsigned inode=0; inode<nnodes_;++inode)
00382         {
00383             x[inode] = *(node_labels_subset[inode].begin());
00384         }
00385     }
00386     return arcConsistent;
00387 }
00389 bool mmn_diffusion_solver::continue_diffusion()
00390 {
00391     ++count_;
00392     bool retstate=true;
00393     if (max_delta_<epsilon_ || count_>max_iterations_)
00394     {
00395         //Terminate on either convergence or max iteration count reached
00396         retstate = false;
00397     }
00398     else if (count_>min_iterations_)
00399     {
00400         //Final convergence can be slow, but the final stages may well not affect the highest layer
00401         //So periodically check if we have reached an arc consistent top layer solution with non-increasing value
00402         if (count_ % gACS_CHECK_PERIOD==0)
00403         {
00404             vcl_vector<unsigned> x(nnodes_,0);
00405             bool ok = arc_consistent_solution(x);
00406             if (ok)
00407             {
00408                 double soln_val=solution_cost(x);
00409                 if (verbose_)
00410                 {
00411                     vcl_cout<<"Arc consistent solution reached. "
00412                             <<"\tSolution value= "<<soln_val<<"\tprev soln val= "<<soln_val_prev_<<vcl_endl;
00413                 }
00414                 if (vcl_fabs(soln_val-soln_val_prev_)<epsilon_)
00415                 {
00416                     ++nConverging_;
00417                     if (nConverging_>gNCONVERGED)
00418                     {
00419                         retstate = false;
00420                     }
00421                 }
00422                 else
00423                 {
00424                     nConverging_=0;
00425                 }
00427                 soln_val_prev_=soln_val;
00428             }
00429             else
00430             {
00431                 if (verbose_)
00432                 {
00433                     vcl_cout<<"Solution is not yet arc consistent."<<vcl_endl;
00434                 }
00435                 nConverging_=0;
00436             }
00437         }
00438     }
00439     return retstate;
00440 }