00001 #include "mmn_lbp_solver.h"
00002 #include <vcl_algorithm.h>
00003 #include <vcl_functional.h>
00004 #include <vcl_iterator.h>
00005 #include <vcl_sstream.h>
00006 #include <mbl/mbl_exception.h>
00007 #include <mbl/mbl_stl.h>
00008 #include <mbl/mbl_parse_block.h>
00009 #include <mbl/mbl_read_props.h>
00010
00011
00012
00013
00014
00015
00016 const unsigned mmn_lbp_solver::NHISTORY_=5;
00017 const unsigned mmn_lbp_solver::NCYCLE_DETECT_=7;
00018
00019
00020 mmn_lbp_solver::mmn_lbp_solver()
00021 : nnodes_(0), max_iterations_(100), min_simple_iterations_(25),
00022 epsilon_(1E-6), alpha_(0.6), smooth_on_cycling_(true),
00023 max_cycle_detection_count_(3), verbose_(false),
00024 msg_upd_mode_(mmn_lbp_solver::eRANDOM_SERIAL)
00025 {
00026 init();
00027 }
00028
00029
00030 mmn_lbp_solver::mmn_lbp_solver(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs)
00031 : max_iterations_(100), min_simple_iterations_(25),
00032 epsilon_(1E-6), alpha_(0.6), smooth_on_cycling_(true),
00033 max_cycle_detection_count_(3), verbose_(false),
00034 msg_upd_mode_(mmn_lbp_solver::eRANDOM_SERIAL)
00035 {
00036 init();
00037 set_arcs(num_nodes,arcs);
00038 }
00039
00040 void mmn_lbp_solver::init()
00041 {
00042 count_=0;
00043 max_delta_=-1.0;
00044 soln_history_.clear();
00045 max_delta_history_.clear();
00046 isCycling_ = false;
00047 nrevisits_=0;
00048 cycle_detection_count_=0;
00049 zbest_on_cycle_detection_=0.0;
00050 }
00051
00052
00053 void mmn_lbp_solver::set_arcs(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs)
00054 {
00055 nnodes_=num_nodes;
00056 arcs_ = arcs;
00057
00058 unsigned max_node=0;
00059 for (unsigned i=0; i<arcs.size();++i)
00060 {
00061 max_node=vcl_max(max_node,arcs[i].max_v());
00062 }
00063 if (nnodes_ != max_node+1)
00064 {
00065 vcl_cerr<<"Arcs appear to be inconsistent with number of nodes in mmn_lbp_solver::set_arcs\n"
00066 <<"Max mode in Arcs is: "<<max_node<<" but number of nodes= "<<nnodes_<<'\n';
00067 }
00068
00069 graph_.build(nnodes_,arcs_);
00070
00071 messages_.resize(nnodes_);
00072 messages_upd_=messages_;
00073 arc_costs_.resize(nnodes_);
00074
00075
00076 max_iterations_ = min_simple_iterations_ + nnodes_ + arcs_.size();
00077 }
00078
00079 double mmn_lbp_solver::solve(
00080 const vcl_vector<vnl_vector<double> >& node_cost,
00081 const vcl_vector<vnl_matrix<double> >& pair_cost,
00082 vcl_vector<unsigned>& x)
00083 {
00084 return (*this)(node_cost,pair_cost,x);
00085 }
00086
00087
00088 double mmn_lbp_solver::operator()(const vcl_vector<vnl_vector<double> >& node_costs,
00089 const vcl_vector<vnl_matrix<double> >& pair_costs,
00090 vcl_vector<unsigned>& x)
00091 {
00092 init();
00093
00094 x.resize(nnodes_);
00095 vcl_fill(x.begin(),x.end(),0);
00096 belief_.resize(nnodes_);
00097
00098 node_costs_.resize(nnodes_);
00099 for (unsigned i=0;i<nnodes_;++i)
00100 {
00101
00102 node_costs_[i] = - node_costs[i];
00103 }
00104
00105
00106 const vcl_vector<vcl_vector<vcl_pair<unsigned,unsigned> > >& neighbourhoods=graph_.node_data();
00107 for (unsigned inode=0; inode<neighbourhoods.size();++inode)
00108 {
00109 unsigned nbstates=node_costs_[inode].size();
00110 belief_[inode].set_size(nbstates);
00111
00112 double priorb=vcl_log(1.0/double(nbstates));
00113 belief_[inode].fill(priorb);
00114
00115 const vcl_vector<vcl_pair<unsigned,unsigned> >& neighbours=neighbourhoods[inode];
00116 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIter=neighbours.begin();
00117 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIterEnd=neighbours.end();
00118 while (neighIter != neighIterEnd)
00119 {
00120 unsigned arcId=neighIter->second;
00121 vnl_matrix<double>& linkCosts = arc_costs_[inode][neighIter->first];
00122 const vnl_matrix<double >& srcArcCosts=pair_costs[arcId];
00123 mmn_arc& arc=arcs_[arcId];
00124 unsigned v1=arc.v1;
00125 unsigned v2=arc.v2;
00126 unsigned minv=arc.min_v();
00127 if (inode!=v1 && inode!=v2)
00128 {
00129 vcl_string msg("Graph inconsistency in mmn_lbp_solver::operator()\n");
00130 vcl_ostringstream os;
00131 os <<"Source node is "<<inode<<" but arc to alleged neighbour joins nodes "<<v1<<"\t to "<<v2<<'\n';
00132 msg+= os.str();
00133 vcl_cerr<<msg<<vcl_endl;
00134 throw mbl_exception_abort(msg);
00135 }
00136
00137 if (inode==minv)
00138 {
00139 linkCosts=srcArcCosts;
00140 }
00141 else
00142 {
00143 linkCosts=srcArcCosts.transpose();
00144 }
00145 linkCosts*= -1.0;
00146 unsigned nstates=linkCosts.cols();
00147 double dnstates=double(nstates);
00148 messages_[inode][neighIter->first].set_size(nstates);
00149 messages_[inode][neighIter->first].fill(vcl_log(1.0/dnstates));
00150
00151 ++neighIter;
00152 }
00153 }
00154
00155 messages_upd_ = messages_;
00156
00157
00158 vcl_vector<unsigned > random_indices(nnodes_,0);
00159 mbl_stl_increments(random_indices.begin(),random_indices.end(),0);
00160
00161 do
00162 {
00163 max_delta_=-1.0;
00164 switch (msg_upd_mode_)
00165 {
00166 case eALL_PARALLEL:
00167 {
00168
00169 for (unsigned inode=0; inode<nnodes_;++inode)
00170 {
00171 update_messages_to_neighbours(inode,node_costs_[inode]);
00172 }
00173 messages_ = messages_upd_;
00174 }
00175 break;
00176
00177 case eRANDOM_SERIAL:
00178 default:
00179 {
00180 vcl_random_shuffle(random_indices.begin(),random_indices.end());
00181
00182
00183 for (unsigned knode=0; knode<nnodes_;++knode)
00184 {
00185 unsigned inode=random_indices[knode];
00186 update_messages_to_neighbours(inode,node_costs_[inode]);
00187 messages_[inode] = messages_upd_[inode];
00188 }
00189 }
00190 }
00191
00192 if (verbose_)
00193 {
00194 vcl_cout<<"Max message delta at iteration "<<count_<<"\t is "<<max_delta_<<vcl_endl;
00195 }
00196
00197 calculate_beliefs(x);
00198 }
00199 while (continue_propagation(x));
00200
00201
00202 calculate_beliefs(x);
00203
00204 for (unsigned inode=0; inode<nnodes_;++inode)
00205 {
00206 renormalise_log(belief_[inode]);
00207 for (unsigned i=0; i<belief_[inode].size();i++)
00208 {
00209 belief_[inode][i]=vcl_exp(belief_[inode][i]);
00210 }
00211 }
00212
00213
00214
00215
00216 if (!isCycling_)
00217 {
00218 return -solution_cost(x);
00219 }
00220 else
00221 {
00222 double zbest=best_solution_cost_in_history(x);
00223 if (verbose_)
00224 {
00225 vcl_cout<<"Best solution when cycling condition first detected was: "<<zbest_on_cycle_detection_<<vcl_endl
00226 <<"Final Best solution : "<<zbest<<vcl_endl;
00227 }
00228 return -zbest;
00229 }
00230 }
00231
00232 void mmn_lbp_solver::calculate_beliefs(vcl_vector<unsigned>& x)
00233 {
00234
00235
00236
00237 const vcl_vector<vcl_vector<vcl_pair<unsigned,unsigned> > >& neighbourhoods=graph_.node_data();
00238 for (unsigned inode=0; inode<neighbourhoods.size();++inode)
00239 {
00240 unsigned bestState=0;
00241 double best=-1.0E012;
00242 unsigned nstates=node_costs_[inode].size();
00243 for (unsigned istate=0; istate<nstates;++istate)
00244 {
00245 double b=node_costs_[inode][istate];
00246
00247 const vcl_vector<vcl_pair<unsigned,unsigned> >& neighbours=graph_.node_data()[inode];
00248 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIter=neighbours.begin();
00249 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIterEnd=neighbours.end();
00250 while (neighIter != neighIterEnd)
00251 {
00252 vnl_vector<double>& msgsFromNeigh = messages_[neighIter->first][inode];
00253 b+= msgsFromNeigh[istate];
00254 ++neighIter;
00255 }
00256 belief_[inode][istate]=b;
00257 if (b>best)
00258 {
00259 best=b;
00260 bestState=istate;
00261 }
00262 }
00263 x[inode]=bestState;
00264
00265 renormalise_log(belief_[inode]);
00266 }
00267 }
00268
00269 double mmn_lbp_solver::solution_cost(vcl_vector<unsigned>& x)
00270 {
00271
00272 double sumNodes=0.0;
00273
00274 vcl_vector<vnl_vector<double> >::const_iterator nodeIter=node_costs_.begin();
00275 vcl_vector<vnl_vector<double> >::const_iterator nodeIterEnd=node_costs_.end();
00276 vcl_vector<unsigned >::const_iterator stateIter=x.begin();
00277 while (nodeIter != nodeIterEnd)
00278 {
00279 const vnl_vector<double>& ncosts = *nodeIter;
00280 sumNodes+=ncosts[*stateIter];
00281 ++nodeIter;++stateIter;
00282 }
00283
00284
00285 vcl_vector<mmn_arc>::const_iterator arcIter=arcs_.begin();
00286 vcl_vector<mmn_arc>::const_iterator arcIterEnd=arcs_.end();
00287 double sumArcs=0.0;
00288 while (arcIter != arcIterEnd)
00289 {
00290 unsigned nodeId1=arcIter->v1;
00291 unsigned nodeId2=arcIter->v2;
00292 sumArcs += arc_costs_[nodeId1][nodeId2](x[nodeId1],x[nodeId2]);
00293 ++arcIter;
00294 }
00295
00296 return sumNodes+sumArcs;
00297 }
00298
00299 double mmn_lbp_solver::best_solution_cost_in_history(vcl_vector<unsigned>& x)
00300 {
00301 double zbest=solution_cost(x);
00302 vcl_vector<double> solution_vals(soln_history_.size());
00303 vcl_deque<vcl_vector<unsigned> >::iterator xIter=soln_history_.begin();
00304 vcl_deque<vcl_vector<unsigned> >::iterator xIterEnd=soln_history_.end();
00305 vcl_deque<vcl_vector<unsigned> >::iterator xIterBest=soln_history_.end()-1;
00306 while (xIter != xIterEnd)
00307 {
00308 double z = solution_cost(*xIter);
00309 if (z>zbest)
00310 {
00311 zbest=z;
00312 xIterBest=xIter;
00313 }
00314 ++xIter;
00315 }
00316 x=*xIterBest;
00317 return zbest;
00318 }
00319
00320 void mmn_lbp_solver::update_messages_to_neighbours(unsigned inode,
00321 const vnl_vector<double>& node_cost)
00322 {
00323
00324
00325 const vcl_vector<vcl_pair<unsigned,unsigned> >& neighbours=graph_.node_data()[inode];
00326
00327 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIter=neighbours.begin();
00328 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator neighIterEnd=neighbours.end();
00329 while (neighIter != neighIterEnd)
00330 {
00331 vnl_vector<double>& msgsToNeigh = messages_[inode][neighIter->first];
00332 vnl_matrix<double>& linkCosts = arc_costs_[inode][neighIter->first];
00333 unsigned nTargetStates=msgsToNeigh.size();
00334 unsigned nSrcStates=linkCosts.rows();
00335 if (nSrcStates!=node_cost.size())
00336 {
00337 vcl_string msg("Inconsistent array sizes in mmn_lbp_solver::update_messages_to_neighbours\n");
00338 msg+= "Inconsistent array sizes in mmn_lbp_solver::update_messages_to_neighbours ";
00339
00340 vcl_cerr<<msg<<vcl_endl;
00341 throw mbl_exception_abort(msg);
00342 }
00343 for (unsigned jstate=0;jstate<nTargetStates;++jstate)
00344 {
00345 double max_istates= -1E99;
00346 for (unsigned istate=0; istate<nSrcStates;++istate)
00347 {
00348 double logProdIncoming=0.0;
00349
00350 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator tomeIter=neighbours.begin();
00351 vcl_vector<vcl_pair<unsigned,unsigned> >::const_iterator tomeIterEnd=neighbours.end();
00352 while (tomeIter != tomeIterEnd)
00353 {
00354 if (tomeIter != neighIter)
00355 {
00356 unsigned k=tomeIter->first;
00357 double m_ki_xi=messages_[k][inode][istate];
00358 logProdIncoming += m_ki_xi;
00359 }
00360 ++tomeIter;
00361 }
00362 double acost=linkCosts(istate,jstate);
00363 double ncost=node_cost[istate];
00364 double logMij=acost+ncost+logProdIncoming;
00365 max_istates = vcl_max(max_istates,logMij);
00366 }
00367 if (cycle_detection_count_>0 && smooth_on_cycling_)
00368 {
00369 messages_upd_[inode][neighIter->first][jstate]=alpha_*max_istates+(1.0-alpha_)*messages_[inode][neighIter->first][jstate];
00370 }
00371 else
00372 {
00373 messages_upd_[inode][neighIter->first][jstate]=max_istates;
00374 }
00375 }
00376 renormalise_log(messages_upd_[inode][neighIter->first]);
00377 #if 0
00378 vcl_cout<<"Iteration "<<count_<<"Msg Upd Node\t"<<inode<<"\tto node "<<neighIter->first<<'\t';
00379
00380 vcl_copy(messages_upd_[inode][neighIter->first].begin(),
00381 messages_upd_[inode][neighIter->first].end(),
00382 vcl_ostream_iterator<double>(vcl_cout,"\t"));
00383 vcl_cout<<vcl_endl
00384 <<"Iteration "<<count_<<"Msg Prv Node\t"<<inode<<"\tto node "<<neighIter->first<<'\t';
00385
00386 vcl_copy(messages_[inode][neighIter->first].begin(),
00387 messages_[inode][neighIter->first].end(),
00388 vcl_ostream_iterator<double>(vcl_cout,"\t"));
00389
00390 vcl_cout<<vcl_endl<<vcl_endl;
00391 #endif
00392
00393
00394 vnl_vector<double > delta_message=(messages_upd_[inode][neighIter->first]-
00395 messages_[inode][neighIter->first]);
00396 double delta=delta_message.inf_norm();
00397 max_delta_ = vcl_max(max_delta_,delta);
00398
00399 ++neighIter;
00400 }
00401 }
00402
00403 void mmn_lbp_solver::renormalise_log(vnl_vector<double >& logMessageVec)
00404 {
00405 vnl_vector<double >::iterator stateIter=logMessageVec.begin();
00406 vnl_vector<double >::iterator stateIterEnd=logMessageVec.end();
00407 double probSum=0.0;
00408 while (stateIter != stateIterEnd)
00409 {
00410 probSum+=vcl_exp(*stateIter);
00411 ++stateIter;
00412 }
00413
00414
00415 double alpha = 1.0/probSum;
00416
00417 double logAlpha=vcl_log(alpha);
00418 vcl_transform(logMessageVec.begin(),logMessageVec.end(),
00419 logMessageVec.begin(),
00420 vcl_bind2nd(vcl_plus<double>(),logAlpha));
00421 }
00422
00423 bool mmn_lbp_solver::continue_propagation(vcl_vector<unsigned>& x)
00424 {
00425 ++count_;
00426 bool retstate=true;
00427 if (max_delta_<epsilon_ || count_>max_iterations_)
00428 {
00429
00430 retstate = false;
00431 }
00432 else if (count_ < min_simple_iterations_)
00433 {
00434
00435 retstate = true;
00436 }
00437 else if (cycle_detection_count_<2 &&
00438 vcl_count_if(max_delta_history_.begin(),max_delta_history_.end(),
00439 vcl_bind1st(vcl_less<double >(),max_delta_))
00440 == int(max_delta_history_.size()))
00441 {
00442 retstate =true;
00443 }
00444 else
00445 {
00446 isCycling_=false;
00447
00448 vcl_deque<vcl_vector<unsigned > >::iterator finder=vcl_find(soln_history_.begin(),soln_history_.end(),x);
00449 if (finder != soln_history_.end())
00450 {
00451 ++nrevisits_;
00452 }
00453 else
00454 {
00455 nrevisits_=0;
00456 }
00457 if (nrevisits_>NCYCLE_DETECT_)
00458 {
00459 isCycling_=true;
00460 ++cycle_detection_count_;
00461 vcl_cout<<"!!!! Loopy Belief is CYCLING... "<<vcl_endl;
00462 }
00463 if (isCycling_)
00464 {
00465 if (cycle_detection_count_==1)
00466 {
00467 vcl_vector<unsigned > xdummy=x;
00468 zbest_on_cycle_detection_=best_solution_cost_in_history(xdummy);
00469 }
00470 if (smooth_on_cycling_ && cycle_detection_count_<max_cycle_detection_count_)
00471 {
00472 nrevisits_=0;
00473 vcl_cout<<"Initiating message alpha smoothing to try and break cycling..."<<vcl_endl;
00474 soln_history_.clear();
00475 }
00476 else
00477 {
00478 vcl_cout<<"Abort and pick best solution in history."<<vcl_endl;
00479 retstate= false;
00480 }
00481 }
00482 }
00483
00484 max_delta_history_.push_back(max_delta_);
00485 if (max_delta_history_.size()>NHISTORY_)
00486 {
00487 max_delta_history_.pop_front();
00488 }
00489 soln_history_.push_back(x);
00490 if (soln_history_.size()>NHISTORY_)
00491 {
00492 soln_history_.pop_front();
00493 }
00494
00495 return retstate;
00496 }
00497
00498
00499
00500
00501
00502
00503 bool mmn_lbp_solver::set_from_stream(vcl_istream &is)
00504 {
00505
00506 vcl_string s = mbl_parse_block(is);
00507 vcl_istringstream ss(s);
00508 mbl_read_props_type props = mbl_read_props_ws(ss);
00509
00510
00511
00512
00513 mbl_read_props_look_for_unused_props(
00514 "mmn_lbp_solver::set_from_stream", props, mbl_read_props_type());
00515 return true;
00516 }
00517
00518
00519
00520
00521
00522
00523 short mmn_lbp_solver::version_no() const
00524 {
00525 return 1;
00526 }
00527
00528
00529
00530
00531
00532 vcl_string mmn_lbp_solver::is_a() const
00533 {
00534 return vcl_string("mmn_lbp_solver");
00535 }
00536
00537
00538 mmn_solver* mmn_lbp_solver::clone() const
00539 {
00540 return new mmn_lbp_solver(*this);
00541 }
00542
00543
00544
00545
00546
00547 void mmn_lbp_solver::print_summary(vcl_ostream& os) const
00548 {
00549 os<<"This is a "<<is_a()<<'\t'<<"with "<<nnodes_<<" nodes"<<vcl_endl;
00550 }
00551
00552
00553
00554
00555
00556 void mmn_lbp_solver::b_write(vsl_b_ostream& bfs) const
00557 {
00558 vsl_b_write(bfs,version_no());
00559 }
00560
00561
00562
00563
00564
00565 void mmn_lbp_solver::b_read(vsl_b_istream& bfs)
00566 {
00567 if (!bfs) return;
00568 short version;
00569 vsl_b_read(bfs,version);
00570 switch (version)
00571 {
00572 case (1):
00573 break;
00574 default:
00575 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&)\n"
00576 << " Unknown version number "<< version << vcl_endl;
00577 bfs.is().clear(vcl_ios::badbit);
00578 return;
00579 }
00580 }
00581