contrib/mul/mmn/mmn_lbp_solver.h
Go to the documentation of this file.
00001 #ifndef mmn_lbp_solver_h_
00002 #define mmn_lbp_solver_h_
00003 //:
00004 // \file
00005 // \brief Run loopy belief propagation over the graph
00006 // \author Martin Roberts
00007 
00008 #include <vcl_vector.h>
00009 #include <vcl_map.h>
00010 #include <vcl_deque.h>
00011 #include <vnl/vnl_vector.h>
00012 #include <vnl/vnl_matrix.h>
00013 #include <mmn/mmn_arc.h>
00014 #include <mmn/mmn_graph_rep1.h>
00015 #include <mmn/mmn_solver.h>
00016 #include <vcl_iosfwd.h>
00017 
00018 //: Run loopy belief propagation (LBP) to estimate the marginal probabilities of all node states
00019 //  Then use the converged LBP messages to also estimate the overall most likely configuration
00020 //
00021 // Can be used for non-tree graphs, but convergence to the optimum is not guaranteed.
00022 // It should converge if there is at most one loop in the graph.
00023 // The graph is built internally in mmn_graph_rep1 form from the input arcs.
00024 
00025 class mmn_lbp_solver: public mmn_solver
00026 {
00027  public:
00028     //: Message update mode type (all in parallel, or randomised node order with immediate effect)
00029     enum msg_update_t {eALL_PARALLEL,eRANDOM_SERIAL};
00030  private:
00031     // In the two typedefs below, the map key is the neighbour's (target) node id
00032 
00033     //: Messages from one source node; inner vector indexed by target node state ID
00034     typedef vcl_map<unsigned,vnl_vector<double > > message_set_t;
00035 
00036     //: Matrix referenced by [source node state ID][target node state ID]
00037     // Map key is the target node ID
00038     typedef vcl_map<unsigned, vnl_matrix<double > > neigh_arc_cost_t;
00039 
00040     //: Graph form of the arcs (so each node's neighbours are conveniently to hand)
00041     mmn_graph_rep1 graph_;
00042 
00043     //: The arcs from which the graph was generated
00044     vcl_vector<mmn_arc> arcs_;
00045 
00046     //: Total number of nodes
00047     unsigned nnodes_;
00048 
00049     //: Workspace for the cost of each arc (presumably vector index is source node ID - TODO confirm)
00050     vcl_vector<neigh_arc_cost_t > arc_costs_;
00051 
00052     //: All the messages at the previous iteration (vector index is source node)
00053     vcl_vector<message_set_t > messages_;
00054     //: Updated messages calculated during this iteration (vector index is source node)
00055     vcl_vector<message_set_t > messages_upd_;
00056 
00057     //: Node costs (outer vector is node ID, inner vnl_vector is by state value)
00058     vcl_vector<vnl_vector<double> > node_costs_;
00059 
00060     //: Belief (probability) for each state of each node
00061     // Assumes input node costs are well-normalised for these to be proper probabilities
00062     vcl_vector<vnl_vector<double> > belief_;
00063 
00064     //: Previous N solutions (used to trap cycling)
00065     vcl_deque<vcl_vector<unsigned  > > soln_history_;
00066 
00067     //: Previous max_delta_ values (used to check they are still decreasing)
00068     vcl_deque<double  > max_delta_history_;
00069 
00070     //: Current iteration count
00071     unsigned count_;
00072 
00073     //: Max change in any message value over this iteration
00074     double max_delta_;
00075 
00076     //: Maximum number of iterations allowed
00077     unsigned max_iterations_;
00078 
00079     //: Minimum number of iterations before checking for solution looping (cycling)
00080     unsigned min_simple_iterations_;
00081 
00082     //: Convergence criterion on max_delta_
00083     double epsilon_;
00084 
00085     //: Number of times a solution in soln_history_ is revisited
00086     unsigned nrevisits_;
00087 
00088     //: True when a cycle condition has been detected
00089     bool isCycling_;
00090 
00091     //: Number of times cycling has been detected
00092     unsigned cycle_detection_count_;
00093 
00094     //: Message update smoothing constant (used if cycling is detected)
00095     double alpha_;
00096 
00097     //: Should message updates be smoothed during cycling?
00098     bool smooth_on_cycling_;
00099 
00100     //: Maximum number of allowed cycle detections
00101     // NOTE: only used if smooth_on_cycling_ is true;
00102     // otherwise we give up after the first cycle is detected
00103     unsigned max_cycle_detection_count_;
00104 
00105     //: Solution value when cycling was first detected
00106     double zbest_on_cycle_detection_;
00107 
00108     //: Verbose debug output
00109     bool verbose_;
00110     //: Message update mode (all in parallel, or randomised serial)
00111     msg_update_t msg_upd_mode_;
00112 
00113     //: Magic numbers for cycle detection
00114     static const unsigned NHISTORY_;
00115     static const unsigned NCYCLE_DETECT_;
00116 
00117     //: Check whether to carry on propagating (presumably x is the current solution - TODO confirm)
00118     bool continue_propagation(vcl_vector<unsigned>& x);
00119 
00120     //: Update all messages from input node to its neighbours
00121     void update_messages_to_neighbours(unsigned inode,
00122                                        const vnl_vector<double>& node_cost);
00123 
00124     //: Renormalise messages (assume they represent log probabilities) so SUM(exp) over target states is 1
00125     void renormalise_log(vnl_vector<double >& logMessageVec);
00126 
00127     //: Reset iteration counters
00128     void init();
00129     //: Calculate final sum of node and arc values for solution x
00130     double solution_cost(vcl_vector<unsigned>& x);
00131     //: Cost of the best solution held in soln_history_ (presumably also returned in x - TODO confirm)
00132     double best_solution_cost_in_history(vcl_vector<unsigned>& x);
00133 
00134     //: Update beliefs and calculate changes therein
00135     void calculate_beliefs(vcl_vector<unsigned>& x);
00136  public:
00137     //: Default constructor
00138     mmn_lbp_solver();
00139 
00140     //: Construct with arcs
00141     mmn_lbp_solver(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs);
00142 
00143     //: Input the arcs that define the graph
00144     virtual void set_arcs(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs);
00145 
00146     //: Find values for each node which minimise the total cost
00147     //  \param node_cost: node_cost[i][j] is cost of selecting value j for node i
00148     //  \param arc_cost: arc_cost[a](i,j) is cost of selecting values (i,j) for the nodes at the ends of arc a.
00149     //  \param x: On exit, x[i] gives choice for node i
00150     // NOTE: If arc a connects nodes v1,v2, the associated arc_cost is ordered
00151     // with the node with the lowest index being the first parameter.  Thus if
00152     // v1 has value i1, v2 has value i2, then the cost of this choice is
00153     // (v1<v2?arc_cost[a](i1,i2):arc_cost[a](i2,i1))
00154     //
00155     // Returns the minimum cost. Note that internally we deal with a maximisation
00156     // after negating the input costs, which are assumed to represent -log probabilities.
00157     // If the states of a node in node_cost are not well normalised as a probability,
00158     // the algorithm should still work in some sense,
00159     // but the meaning of the belief_ objects is then not really well-defined.
00160     // As it is marginal "belief" that is maximised, inputting non-normalised data may not give quite the
00161     // expected answer - there may be some biases, in effect implicit weightings to particular nodes.
00162     double operator()(const vcl_vector<vnl_vector<double> >& node_cost,
00163                       const vcl_vector<vnl_matrix<double> >& arc_cost,
00164                       vcl_vector<unsigned>& x);
00165 
00166     //: Find values for each node which minimise the total cost (mmn_solver interface)
00167     //  Presumably equivalent to operator(), with pair_cost playing the role of arc_cost - TODO confirm
00168     virtual double solve(
00169                  const vcl_vector<vnl_vector<double> >& node_cost,
00170                  const vcl_vector<vnl_matrix<double> >& pair_cost,
00171                  vcl_vector<unsigned>& x);
00172     //: Return the beliefs, i.e. the marginal probabilities of each node's states
00173     const vcl_vector<vnl_vector<double>  >&  belief() const {return belief_;}
00174 
00175     //: Final iteration count
00176     unsigned count() const {return count_;}
00177 
00178     //: Set true to alpha-smooth message updates when cycling is detected
00179     // This may break the cycling condition
00180     void set_smooth_on_cycling(bool bOn) {smooth_on_cycling_=bOn;}
00181     //: Set maximum number of allowed cycle detections (only used if smooth_on_cycling_ is true)
00182     void set_max_cycle_detection_count_(unsigned max_cycle_detection_count) {max_cycle_detection_count_=max_cycle_detection_count;}
00183     //: Set true for verbose debug output
00184     void set_verbose(bool verbose) {verbose_=verbose;}
00185 
00186     //: Set message update mode (parallel or randomised serial)
00187     void set_msg_upd_mode(msg_update_t msg_upd_mode) {msg_upd_mode_ = msg_upd_mode;}
00188 
00189     //: Initialise from a text stream
00190     virtual bool set_from_stream(vcl_istream &is);
00191 
00192     //: Version number for I/O
00193     short version_no() const;
00194 
00195     //: Name of the class
00196     virtual vcl_string is_a() const;
00197 
00198     //: Create a copy on the heap and return base class pointer
00199     virtual mmn_solver* clone() const;
00200 
00201     //: Print class to os
00202     virtual void print_summary(vcl_ostream& os) const;
00203 
00204     //: Save class to binary file stream
00205     virtual void b_write(vsl_b_ostream& bfs) const;
00206 
00207     //: Load class from binary file stream
00208     virtual void b_read(vsl_b_istream& bfs);
00209 };
00210 
00211 #endif // mmn_lbp_solver_h_