contrib/mul/mmn/mmn_lbp_solver.cxx
Go to the documentation of this file.
00001 #include "mmn_lbp_solver.h"
00002 #include <vcl_algorithm.h>
00003 #include <vcl_functional.h>
00004 #include <vcl_iterator.h>
00005 #include <vcl_sstream.h>
00006 #include <mbl/mbl_exception.h>
00007 #include <mbl/mbl_stl.h>
00008 #include <mbl/mbl_parse_block.h>
00009 #include <mbl/mbl_read_props.h>
00010 
00011 //:
00012 // \file
00013 // \brief Run loopy belief propagation to estimate maximum marginal probabilities of all node states
00014 // \author Martin Roberts
00015 
00016 const unsigned mmn_lbp_solver::NHISTORY_=5;
00017 const unsigned mmn_lbp_solver::NCYCLE_DETECT_=7;
00018 
00019 //: Default constructor
00020 mmn_lbp_solver::mmn_lbp_solver()
00021     : nnodes_(0), max_iterations_(100), min_simple_iterations_(25),
00022       epsilon_(1E-6), alpha_(0.6), smooth_on_cycling_(true),
00023       max_cycle_detection_count_(3), verbose_(false),
00024       msg_upd_mode_(mmn_lbp_solver::eRANDOM_SERIAL)
00025 {
00026     init();
00027 }
00028 
00029 //: Construct with arcs
00030 mmn_lbp_solver::mmn_lbp_solver(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs)
00031     : max_iterations_(100), min_simple_iterations_(25),
00032       epsilon_(1E-6), alpha_(0.6), smooth_on_cycling_(true),
00033       max_cycle_detection_count_(3), verbose_(false),
00034       msg_upd_mode_(mmn_lbp_solver::eRANDOM_SERIAL)
00035 {
00036     init();
00037     set_arcs(num_nodes,arcs);
00038 }
00039 
00040 void mmn_lbp_solver::init()
00041 {
00042     count_=0;
00043     max_delta_=-1.0;
00044     soln_history_.clear();
00045     max_delta_history_.clear();
00046     isCycling_ = false;
00047     nrevisits_=0;
00048     cycle_detection_count_=0;
00049     zbest_on_cycle_detection_=0.0;
00050 }
00051 
00052 //: Pass in the arcs, which are then used to build the graph object
00053 void mmn_lbp_solver::set_arcs(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs)
00054 {
00055     nnodes_=num_nodes;
00056     arcs_ = arcs;
00057     //Verify consistency
00058     unsigned max_node=0;
00059     for (unsigned i=0; i<arcs.size();++i)
00060     {
00061         max_node=vcl_max(max_node,arcs[i].max_v());
00062     }
00063     if (nnodes_ != max_node+1)
00064     {
00065         vcl_cerr<<"Arcs appear to be inconsistent with number of nodes in mmn_lbp_solver::set_arcs\n"
00066                 <<"Max mode in Arcs is: "<<max_node<<" but number of nodes= "<<nnodes_<<'\n';
00067     }
00068 
00069     graph_.build(nnodes_,arcs_);
00070 
00071     messages_.resize(nnodes_);
00072     messages_upd_=messages_;
00073     arc_costs_.resize(nnodes_);
00074 
00075     //Set max iterations, somewhat arbitrarily, increasing with nodes and arcs
00076     max_iterations_ = min_simple_iterations_ + nnodes_ + arcs_.size();
00077 }
00078 
00079 double mmn_lbp_solver::solve(
00080                  const vcl_vector<vnl_vector<double> >& node_cost,
00081                  const vcl_vector<vnl_matrix<double> >& pair_cost,
00082                  vcl_vector<unsigned>& x)
00083 {
00084     return (*this)(node_cost,pair_cost,x);
00085 }
00086 
00087 //: Run the algorithm
00088 double mmn_lbp_solver::operator()(const vcl_vector<vnl_vector<double> >& node_costs,
00089                                   const vcl_vector<vnl_matrix<double> >& pair_costs,
00090                                   vcl_vector<unsigned>& x)
00091 {
00092     init();
00093 
00094     x.resize(nnodes_);
00095     vcl_fill(x.begin(),x.end(),0);
00096     belief_.resize(nnodes_);
00097 
00098     node_costs_.resize(nnodes_);
00099     for (unsigned i=0;i<nnodes_;++i)
00100     {
00101         //: Negate costs to convert to log prob (i.e. internally we do maximisation)
00102         node_costs_[i] = - node_costs[i];
00103     }
00104 
00105     //Initialise message structure and neighbourhood cost representation
00106     const vcl_vector<vcl_vector<vcl_pair<unsigned,unsigned> > >& neighbourhoods=graph_.node_data();
00107     for (unsigned inode=0; inode<neighbourhoods.size();++inode)
00108     {
00109         unsigned nbstates=node_costs_[inode].size();
00110         belief_[inode].set_size(nbstates);
00111 
00112         double priorb=vcl_log(1.0/double(nbstates));
00113         belief_[inode].fill(priorb);
00114 
00115         const vcl_vector<vcl_pair<unsigned,unsigned> >& neighbours=neighbourhoods[inode];
00116         vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIter=neighbours.begin();
00117         vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIterEnd=neighbours.end();
00118         while (neighIter != neighIterEnd)
00119         {
00120             unsigned arcId=neighIter->second;
00121             vnl_matrix<double>& linkCosts = arc_costs_[inode][neighIter->first];
00122             const vnl_matrix<double >& srcArcCosts=pair_costs[arcId];
00123             mmn_arc& arc=arcs_[arcId];
00124             unsigned v1=arc.v1;
00125             unsigned v2=arc.v2;
00126             unsigned minv=arc.min_v();
00127             if (inode!=v1 && inode!=v2)
00128             {
00129                 vcl_string msg("Graph inconsistency in mmn_lbp_solver::operator()\n");
00130                 vcl_ostringstream os;
00131                 os <<"Source node is "<<inode<<" but arc to alleged neighbour joins nodes "<<v1<<"\t to "<<v2<<'\n';
00132                 msg+= os.str();
00133                 vcl_cerr<<msg<<vcl_endl;
00134                 throw mbl_exception_abort(msg);
00135             }
00136 
00137             if (inode==minv)
00138             {
00139                 linkCosts=srcArcCosts;
00140             }
00141             else //transpose arc costs as this node is the 2nd element in the pair_costs matrix
00142             {
00143                 linkCosts=srcArcCosts.transpose();
00144             }
00145             linkCosts*= -1.0; //convert to maximising log prob (not min -log prob)
00146             unsigned nstates=linkCosts.cols();
00147             double dnstates=double(nstates);
00148             messages_[inode][neighIter->first].set_size(nstates);
00149             messages_[inode][neighIter->first].fill(vcl_log(1.0/dnstates)); //set all initial messages to uniform prob
00150 
00151             ++neighIter; //next neighbour of this node
00152         }
00153     } //next node
00154 
00155     messages_upd_ = messages_;
00156 
00157     //Now keep repeating message passing
00158     vcl_vector<unsigned > random_indices(nnodes_,0);
00159     mbl_stl_increments(random_indices.begin(),random_indices.end(),0);
00160 
00161     do
00162     {
00163         max_delta_=-1.0;
00164         switch (msg_upd_mode_)
00165         {
00166             case eALL_PARALLEL:
00167             {
00168                 //Calculate all updates in parallel using only previous iteration messages
00169                 for (unsigned inode=0; inode<nnodes_;++inode)
00170                 {
00171                     update_messages_to_neighbours(inode,node_costs_[inode]);
00172                 }
00173                 messages_ = messages_upd_;
00174             }
00175             break;
00176 
00177             case eRANDOM_SERIAL:
00178             default:
00179             {
00180                 vcl_random_shuffle(random_indices.begin(),random_indices.end());
00181                 //Randomise the order of inter-node messages
00182                 //May help avoid looping
00183                 for (unsigned knode=0; knode<nnodes_;++knode)
00184                 {
00185                     unsigned inode=random_indices[knode];
00186                     update_messages_to_neighbours(inode,node_costs_[inode]);
00187                     messages_[inode] = messages_upd_[inode]; //immediate update for this node
00188                 }
00189             }
00190         }
00191 
00192         if (verbose_)
00193         {
00194             vcl_cout<<"Max message delta at iteration "<<count_<<"\t is "<<max_delta_<<vcl_endl;
00195         }
00196         //Now calculate belief levels of each node's states
00197         calculate_beliefs(x);
00198     }
00199     while (continue_propagation(x));
00200 
00201     //Now calculate final belief levels of each node's states and select the maximising ones
00202     calculate_beliefs(x);
00203 
00204     for (unsigned inode=0; inode<nnodes_;++inode)
00205     {
00206         renormalise_log(belief_[inode]);
00207         for (unsigned i=0; i<belief_[inode].size();i++)
00208         {
00209             belief_[inode][i]=vcl_exp(belief_[inode][i]);
00210         }
00211     }
00212 
00213     //Return -best solution value (i.e. minimised form)
00214     //vcl_cout<<"Calculating solution cost..."<<vcl_endl;
00215 
00216     if (!isCycling_)
00217     {
00218         return -solution_cost(x);
00219     }
00220     else
00221     {
00222         double zbest=best_solution_cost_in_history(x);
00223         if (verbose_)
00224         {
00225             vcl_cout<<"Best solution when cycling condition first detected was: "<<zbest_on_cycle_detection_<<vcl_endl
00226                     <<"Final Best solution : "<<zbest<<vcl_endl;
00227         }
00228         return -zbest;
00229     }
00230 }
00231 
00232 void mmn_lbp_solver::calculate_beliefs(vcl_vector<unsigned>& x)
00233 {
00234     //Now calculate belief levels of each node's states
00235     //NB calculates log belief actually
00236 
00237     const vcl_vector<vcl_vector<vcl_pair<unsigned,unsigned> > >& neighbourhoods=graph_.node_data();
00238     for (unsigned inode=0; inode<neighbourhoods.size();++inode)
00239     {
00240         unsigned bestState=0;
00241         double best=-1.0E012;
00242         unsigned nstates=node_costs_[inode].size();
00243         for (unsigned istate=0; istate<nstates;++istate)
00244         {
00245             double b=node_costs_[inode][istate];
00246             //Now loop over neighbourhood
00247             const vcl_vector<vcl_pair<unsigned,unsigned> >& neighbours=graph_.node_data()[inode];
00248             vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIter=neighbours.begin();
00249             vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIterEnd=neighbours.end();
00250             while (neighIter != neighIterEnd)
00251             {
00252                 vnl_vector<double>& msgsFromNeigh = messages_[neighIter->first][inode];
00253                 b+= msgsFromNeigh[istate];
00254                 ++neighIter;
00255             }
00256             belief_[inode][istate]=b;
00257             if (b>best)
00258             {
00259                 best=b;
00260                 bestState=istate;
00261             }
00262         }
00263         x[inode]=bestState;
00264 
00265         renormalise_log(belief_[inode]);
00266     }
00267 }
00268 
00269 double mmn_lbp_solver::solution_cost(vcl_vector<unsigned>& x)
00270 {
00271     //: Calculate best (max log prob) of solution x
00272     double sumNodes=0.0;
00273     //Sum over all nodes
00274     vcl_vector<vnl_vector<double> >::const_iterator nodeIter=node_costs_.begin();
00275     vcl_vector<vnl_vector<double> >::const_iterator nodeIterEnd=node_costs_.end();
00276     vcl_vector<unsigned >::const_iterator stateIter=x.begin();
00277     while (nodeIter != nodeIterEnd)
00278     {
00279         const vnl_vector<double>& ncosts = *nodeIter;
00280         sumNodes+=ncosts[*stateIter];
00281         ++nodeIter;++stateIter;
00282     }
00283 
00284     // Sum over all arcs
00285     vcl_vector<mmn_arc>::const_iterator arcIter=arcs_.begin();
00286     vcl_vector<mmn_arc>::const_iterator arcIterEnd=arcs_.end();
00287     double sumArcs=0.0;
00288     while (arcIter != arcIterEnd)
00289     {
00290         unsigned nodeId1=arcIter->v1;
00291         unsigned nodeId2=arcIter->v2;
00292         sumArcs += arc_costs_[nodeId1][nodeId2](x[nodeId1],x[nodeId2]);
00293         ++arcIter;
00294     }
00295 
00296     return sumNodes+sumArcs;
00297 }
00298 
00299 double mmn_lbp_solver::best_solution_cost_in_history(vcl_vector<unsigned>& x)
00300 {
00301     double zbest=solution_cost(x);
00302     vcl_vector<double> solution_vals(soln_history_.size());
00303     vcl_deque<vcl_vector<unsigned> >::iterator xIter=soln_history_.begin();
00304     vcl_deque<vcl_vector<unsigned> >::iterator xIterEnd=soln_history_.end();
00305     vcl_deque<vcl_vector<unsigned> >::iterator xIterBest=soln_history_.end()-1;
00306     while (xIter != xIterEnd)
00307     {
00308         double z = solution_cost(*xIter);
00309         if (z>zbest)
00310         {
00311             zbest=z;
00312             xIterBest=xIter;
00313         }
00314         ++xIter;
00315     }
00316     x=*xIterBest;
00317     return zbest;
00318 }
00319 
00320 void mmn_lbp_solver::update_messages_to_neighbours(unsigned inode,
00321                                                    const vnl_vector<double>& node_cost)
00322 {
00323     //Update all messages from this node to its neighbours
00324 
00325     const vcl_vector<vcl_pair<unsigned,unsigned> >& neighbours=graph_.node_data()[inode];
00326 
00327     vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIter=neighbours.begin();
00328     vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIterEnd=neighbours.end();
00329     while (neighIter != neighIterEnd) //Loop over all my neighbours
00330     {
00331         vnl_vector<double>& msgsToNeigh = messages_[inode][neighIter->first];
00332         vnl_matrix<double>& linkCosts = arc_costs_[inode][neighIter->first];
00333         unsigned nTargetStates=msgsToNeigh.size();
00334         unsigned nSrcStates=linkCosts.rows(); //number of source states for this node
00335         if (nSrcStates!=node_cost.size())
00336         {
00337             vcl_string msg("Inconsistent array sizes in mmn_lbp_solver::update_messages_to_neighbours\n");
00338             msg+= "Inconsistent array sizes in mmn_lbp_solver::update_messages_to_neighbours ";
00339 
00340             vcl_cerr<<msg<<vcl_endl;
00341             throw mbl_exception_abort(msg);
00342         }
00343         for (unsigned jstate=0;jstate<nTargetStates;++jstate) //do each state of the target neighbour
00344         {
00345             double max_istates= -1E99; // minus infinity, as initialisation for a maximum
00346             for (unsigned istate=0; istate<nSrcStates;++istate)
00347             {
00348                 double logProdIncoming=0.0;
00349                 //Compute product of all incoming messages to this node i from elsewhere (excluding target node j)
00350                 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator tomeIter=neighbours.begin();
00351                 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator tomeIterEnd=neighbours.end();
00352                 while (tomeIter != tomeIterEnd)
00353                 {
00354                     if (tomeIter != neighIter)
00355                     {
00356                         unsigned k=tomeIter->first;
00357                         double m_ki_xi=messages_[k][inode][istate];
00358                         logProdIncoming += m_ki_xi;
00359                     }
00360                     ++tomeIter;
00361                 }
00362                 double acost=linkCosts(istate,jstate);
00363                 double ncost=node_cost[istate];
00364                 double logMij=acost+ncost+logProdIncoming;
00365                 max_istates = vcl_max(max_istates,logMij);
00366             }
00367             if (cycle_detection_count_>0 && smooth_on_cycling_)
00368             {
00369                 messages_upd_[inode][neighIter->first][jstate]=alpha_*max_istates+(1.0-alpha_)*messages_[inode][neighIter->first][jstate];
00370             }
00371             else
00372             {
00373                 messages_upd_[inode][neighIter->first][jstate]=max_istates;
00374             }
00375         }
00376         renormalise_log(messages_upd_[inode][neighIter->first]);
00377 #if 0
00378         vcl_cout<<"Iteration "<<count_<<"Msg Upd Node\t"<<inode<<"\tto node "<<neighIter->first<<'\t';
00379 
00380         vcl_copy(messages_upd_[inode][neighIter->first].begin(),
00381                  messages_upd_[inode][neighIter->first].end(),
00382                  vcl_ostream_iterator<double>(vcl_cout,"\t"));
00383         vcl_cout<<vcl_endl
00384                 <<"Iteration "<<count_<<"Msg Prv Node\t"<<inode<<"\tto node "<<neighIter->first<<'\t';
00385 
00386         vcl_copy(messages_[inode][neighIter->first].begin(),
00387                  messages_[inode][neighIter->first].end(),
00388                  vcl_ostream_iterator<double>(vcl_cout,"\t"));
00389 
00390         vcl_cout<<vcl_endl<<vcl_endl;
00391 #endif
00392 
00393         //: Compute max change during iteration
00394         vnl_vector<double > delta_message=(messages_upd_[inode][neighIter->first]-
00395                                            messages_[inode][neighIter->first]);
00396         double delta=delta_message.inf_norm();
00397         max_delta_ = vcl_max(max_delta_,delta);
00398 
00399         ++neighIter;
00400     }
00401 }
00402 
00403 void mmn_lbp_solver::renormalise_log(vnl_vector<double >& logMessageVec)
00404 {
00405     vnl_vector<double >::iterator stateIter=logMessageVec.begin();
00406     vnl_vector<double >::iterator stateIterEnd=logMessageVec.end();
00407     double probSum=0.0;
00408     while (stateIter != stateIterEnd)
00409     {
00410         probSum+=vcl_exp(*stateIter);
00411         ++stateIter;
00412     }
00413 
00414     //normalise so probabilities sum to 1
00415     double alpha = 1.0/probSum;
00416     //But now rather than multiplying by alpha, add log(alpha);
00417     double logAlpha=vcl_log(alpha);
00418     vcl_transform(logMessageVec.begin(),logMessageVec.end(),
00419                   logMessageVec.begin(),
00420                   vcl_bind2nd(vcl_plus<double>(),logAlpha));
00421 }
00422 
00423 bool mmn_lbp_solver::continue_propagation(vcl_vector<unsigned>& x)
00424 {
00425     ++count_;
00426     bool retstate=true;
00427     if (max_delta_<epsilon_ || count_>max_iterations_)
00428     {
00429         //Terminate on either convergence or max iteration count reached
00430         retstate = false;
00431     }
00432     else if (count_ < min_simple_iterations_)
00433     {
00434         //always do at least this many if not converged in delta
00435         retstate = true;
00436     }
00437     else if (cycle_detection_count_<2 &&
00438              vcl_count_if(max_delta_history_.begin(),max_delta_history_.end(),
00439                           vcl_bind1st(vcl_less<double >(),max_delta_))
00440              == int(max_delta_history_.size()))
00441     {
00442         retstate =true; //delta is definitely decreasing so keep going unless we've had >2 cycles already
00443     }
00444     else
00445     {
00446         isCycling_=false;
00447         //Check for cycling condition
00448         vcl_deque<vcl_vector<unsigned  > >::iterator finder=vcl_find(soln_history_.begin(),soln_history_.end(),x);
00449         if (finder != soln_history_.end())
00450         {
00451             ++nrevisits_;
00452         }
00453         else
00454         {
00455             nrevisits_=0;
00456         }
00457         if (nrevisits_>NCYCLE_DETECT_)
00458         {
00459             isCycling_=true;
00460             ++cycle_detection_count_;
00461             vcl_cout<<"!!!! Loopy Belief is CYCLING... "<<vcl_endl;
00462         }
00463         if (isCycling_)
00464         {
00465             if (cycle_detection_count_==1)
00466             {
00467                 vcl_vector<unsigned > xdummy=x;
00468                 zbest_on_cycle_detection_=best_solution_cost_in_history(xdummy);
00469             }
00470             if (smooth_on_cycling_ && cycle_detection_count_<max_cycle_detection_count_)
00471             {
00472                 nrevisits_=0;
00473                 vcl_cout<<"Initiating message alpha smoothing to try and break cycling..."<<vcl_endl;
00474                 soln_history_.clear();
00475             }
00476             else
00477             {
00478                 vcl_cout<<"Abort and pick best solution in history."<<vcl_endl;
00479                 retstate= false;
00480             }
00481         }
00482     }
00483 
00484     max_delta_history_.push_back(max_delta_);
00485     if (max_delta_history_.size()>NHISTORY_)
00486     {
00487         max_delta_history_.pop_front();
00488     }
00489     soln_history_.push_back(x);
00490     if (soln_history_.size()>NHISTORY_)
00491     {
00492         soln_history_.pop_front();
00493     }
00494 
00495     return retstate;
00496 }
00497 
00498 
00499 //=======================================================================
00500 // Method: set_from_stream
00501 //=======================================================================
00502 //: Initialise from a string stream
00503 bool mmn_lbp_solver::set_from_stream(vcl_istream &is)
00504 {
00505   // Cycle through stream and produce a map of properties
00506   vcl_string s = mbl_parse_block(is);
00507   vcl_istringstream ss(s);
00508   mbl_read_props_type props = mbl_read_props_ws(ss);
00509 
00510   // No properties expected.
00511 
00512   // Check for unused props
00513   mbl_read_props_look_for_unused_props(
00514       "mmn_lbp_solver::set_from_stream", props, mbl_read_props_type());
00515   return true;
00516 }
00517 
00518 
00519 //=======================================================================
00520 // Method: version_no
00521 //=======================================================================
00522 
00523 short mmn_lbp_solver::version_no() const
00524 {
00525   return 1;
00526 }
00527 
00528 //=======================================================================
00529 // Method: is_a
00530 //=======================================================================
00531 
00532 vcl_string mmn_lbp_solver::is_a() const
00533 {
00534   return vcl_string("mmn_lbp_solver");
00535 }
00536 
00537 //: Create a copy on the heap and return base class pointer
00538 mmn_solver* mmn_lbp_solver::clone() const
00539 {
00540   return new mmn_lbp_solver(*this);
00541 }
00542 
00543 //=======================================================================
00544 // Method: print
00545 //=======================================================================
00546 
00547 void mmn_lbp_solver::print_summary(vcl_ostream& os) const
00548 {
00549     os<<"This is a "<<is_a()<<'\t'<<"with "<<nnodes_<<" nodes"<<vcl_endl;
00550 }
00551 
00552 //=======================================================================
00553 // Method: save
00554 //=======================================================================
00555 
00556 void mmn_lbp_solver::b_write(vsl_b_ostream& bfs) const
00557 {
00558   vsl_b_write(bfs,version_no());
00559 }
00560 
00561 //=======================================================================
00562 // Method: load
00563 //=======================================================================
00564 
00565 void mmn_lbp_solver::b_read(vsl_b_istream& bfs)
00566 {
00567   if (!bfs) return;
00568   short version;
00569   vsl_b_read(bfs,version);
00570   switch (version)
00571   {
00572     case (1):
00573       break;
00574     default:
00575       vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&)\n"
00576                << "           Unknown version number "<< version << vcl_endl;
00577       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00578       return;
00579   }
00580 }
00581