00001 #ifndef mmn_lbp_solver_h_ 00002 #define mmn_lbp_solver_h_ 00003 //: 00004 // \file 00005 // \brief Run loopy belief propagation over the graph 00006 // \author Martin Roberts 00007 00008 #include <vcl_vector.h> 00009 #include <vcl_map.h> 00010 #include <vcl_deque.h> 00011 #include <vnl/vnl_vector.h> 00012 #include <vnl/vnl_matrix.h> 00013 #include <mmn/mmn_arc.h> 00014 #include <mmn/mmn_graph_rep1.h> 00015 #include <mmn/mmn_solver.h> 00016 #include <vcl_iosfwd.h> 00017 00018 //: Run loopy belief to estimate overall marginal probabilities of all node states 00019 // Then use converged LBP messages to also estimate overall most likely configuration 00020 // 00021 // Can use this for non-tree graphs, but convergence to optimum is not absolutely guaranteed 00022 // Should converge if there is at most one loop in the graph 00023 // The input graph is converted to form mmn_graph_rep1 from the input arcs 00024 00025 class mmn_lbp_solver: public mmn_solver 00026 { 00027 public: 00028 //: Message update mode type (all in parallel, or randomised node order with immediate effect} 00029 enum msg_update_t {eALL_PARALLEL,eRANDOM_SERIAL}; 00030 private: 00031 //: in below the map is indexed by the neighbour's node id 00032 00033 //: Inner vector indexed by target node state ID 00034 typedef vcl_map<unsigned,vnl_vector<double > > message_set_t; 00035 00036 //: Matrix referenced by [source node state ID][target node state ID] 00037 // Map ID is target node ID 00038 typedef vcl_map<unsigned, vnl_matrix<double > > neigh_arc_cost_t; 00039 00040 //:Store in graph form (so each node's neighbours are conveniently to hand) 00041 mmn_graph_rep1 graph_; 00042 00043 //: The arcs from which the graph was generated 00044 vcl_vector<mmn_arc> arcs_; 00045 00046 //: Total number of nodes 00047 unsigned nnodes_; 00048 00049 //: Workspace for costs of each arc 00050 vcl_vector<neigh_arc_cost_t > arc_costs_; 00051 00052 //: All the messages at previous iteration (vector index is source node) 00053 vcl_vector<message_set_t > messages_; 00054 //: Update messages calculated during this iteration (vector index is source node) 00055 vcl_vector<message_set_t > messages_upd_; 00056 00057 //: Node costs (outer vector is node ID, inner vnl_vector is by state value) 00058 vcl_vector<vnl_vector<double> > node_costs_; 00059 00060 //: belief prob for each state of each node 00061 // Assumes input node costs are well-normalised for these to be proper probabilities 00062 vcl_vector<vnl_vector<double> > belief_; 00063 00064 //: previous N solutions (used to trap cycling) 00065 vcl_deque<vcl_vector<unsigned > > soln_history_; 00066 00067 //: previous max_delta values(used to check still descending) 00068 vcl_deque<double > max_delta_history_; 00069 00070 //: Current iteration count 00071 unsigned count_; 00072 00073 //: Max change in any message value over this iteration 00074 double max_delta_; 00075 00076 //: max number of iterations allowed 00077 unsigned max_iterations_; 00078 00079 //: min number of iterations before checking for solution looping (cycling) 00080 unsigned min_simple_iterations_; 00081 00082 //: Convergence criterion on max_delta_ 00083 double epsilon_; 00084 00085 //: count of number of times a solution in history is revisited 00086 unsigned nrevisits_; 00087 00088 //: cycle condition detected 00089 bool isCycling_; 00090 00091 //: Number of times cycling has been detected 00092 unsigned cycle_detection_count_; 00093 00094 //: message update smoothing constant (used if cycling detected) 00095 double alpha_; 00096 00097 //: should message update be smoothed during cycling 00098 bool smooth_on_cycling_; 00099 00100 //; Maximum number of allowed cycle detections 00101 //NOTE only used if smooth_on_cycling_ is true 00102 //Otherwise we give up after the first cycle is detected 00103 unsigned max_cycle_detection_count_; 00104 00105 //: solution value when cycling first detected 00106 double zbest_on_cycle_detection_; 00107 00108 //:verbose debug output 00109 bool verbose_; 00110 00111 msg_update_t msg_upd_mode_; 00112 00113 //: Magic numbers for cycle detection 00114 static const unsigned NHISTORY_; 00115 static const unsigned NCYCLE_DETECT_; 00116 00117 //: Check if we carry on 00118 bool continue_propagation(vcl_vector<unsigned>& x); 00119 00120 //: Update all messages from input node to its neighbours 00121 void update_messages_to_neighbours(unsigned inode, 00122 const vnl_vector<double>& node_cost); 00123 00124 //: Renormalise messages (assume they represent log probabilities) so SUM(exp) over target states is 1 00125 void renormalise_log(vnl_vector<double >& logMessageVec); 00126 00127 //: Reset iteration counters 00128 void init(); 00129 //: Calculate final sum of node and arc values 00130 double solution_cost(vcl_vector<unsigned>& x); 00131 00132 double best_solution_cost_in_history(vcl_vector<unsigned>& x); 00133 00134 //: update beliefs and calculate changes therein 00135 void calculate_beliefs(vcl_vector<unsigned>& x); 00136 public: 00137 //: Default constructor 00138 mmn_lbp_solver(); 00139 00140 //: Construct with arcs 00141 mmn_lbp_solver(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs); 00142 00143 //: Input the arcs that define the graph 00144 virtual void set_arcs(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs); 00145 00146 //: Find values for each node with minimise the total cost 00147 // \param node_cost: node_cost[i][j] is cost of selecting value j for node i 00148 // \param pair_cost: pair_cost[a](i,j) is cost of selecting values (i,j) for nodes at end of arc a. 00149 // \param x: On exit, x[i] gives choice for node i 00150 // NOTE: If arc a connects nodes v1,v2, the associated pair_cost is ordered 00151 // with the node with the lowest index being the first parameter. Thus if 00152 // v1 has value i1, v2 has value i2, then the cost of this choice is 00153 // (v1<v2?pair_cost(i1,i2):pair_cost(i2,i1)) 00154 // 00155 // Returns the minimum cost. Note that internally we deal with a maximisation 00156 // after negating the input costs, which are assumed to represent -log probabilities 00157 // If the states of a node in node_cost are not well normalised as a probability 00158 // the algorithm should still work in some sense 00159 // but the meaning of the belief_ objects is not really then well-defined. 00160 // As it is marginal "belief" that is maximised, inputting non-normalised data may not give quite the 00161 // expected answer - there may be some biases, in effect implicit weightings to particular nodes 00162 double operator()(const vcl_vector<vnl_vector<double> >& node_cost, 00163 const vcl_vector<vnl_matrix<double> >& arc_cost, 00164 vcl_vector<unsigned>& x); 00165 00166 //: return the beliefs, i.e. the marginal probabilities of each node's states 00167 // 00168 virtual double solve( 00169 const vcl_vector<vnl_vector<double> >& node_cost, 00170 const vcl_vector<vnl_matrix<double> >& pair_cost, 00171 vcl_vector<unsigned>& x); 00172 00173 const vcl_vector<vnl_vector<double> >& belief() const {return belief_;} 00174 00175 //: final iteration count 00176 unsigned count() const {return count_;} 00177 00178 //: Set true if want to alpha smooth message updates when cycling detected 00179 // This may break the cycling condition 00180 void set_smooth_on_cycling(bool bOn) {smooth_on_cycling_=bOn;} 00181 00182 void set_max_cycle_detection_count_(unsigned max_cycle_detection_count) {max_cycle_detection_count_=max_cycle_detection_count;} 00183 00184 void set_verbose(bool verbose) {verbose_=verbose;} 00185 00186 //: Set message update mode (parallel or randomised serial} 00187 void set_msg_upd_mode(msg_update_t msg_upd_mode) {msg_upd_mode_ = msg_upd_mode;} 00188 00189 //: Initialise from a text stream 00190 virtual bool set_from_stream(vcl_istream &is); 00191 00192 //: Version number for I/O 00193 short version_no() const; 00194 00195 //: Name of the class 00196 virtual vcl_string is_a() const; 00197 00198 //: Create a copy on the heap and return base class pointer 00199 virtual mmn_solver* clone() const; 00200 00201 //: Print class to os 00202 virtual void print_summary(vcl_ostream& os) const; 00203 00204 //: Save class to binary file stream 00205 virtual void b_write(vsl_b_ostream& bfs) const; 00206 00207 //: Load class from binary file stream 00208 virtual void b_read(vsl_b_istream& bfs); 00209 }; 00210 00211 #endif // mmn_lbp_solver_h_