00001 #ifndef mmn_diffusion_solver_h_ 00002 #define mmn_diffusion_solver_h_ 00003 //: 00004 // \file 00005 // \brief Run diffusion algorithm over the graph 00006 // \author Martin Roberts 00007 00008 #include <vcl_vector.h> 00009 #include <vcl_map.h> 00010 #include <vnl/vnl_vector.h> 00011 #include <vnl/vnl_matrix.h> 00012 #include <mmn/mmn_arc.h> 00013 #include <mmn/mmn_graph_rep1.h> 00014 00015 //: Run diffusion algorithm to solve max sum problem 00016 // See T Werner. A Linear Programming Approach to Max-sum problem: A review; 00017 // IEEE Trans on Pattern Recog & Machine Intell, July 2007 00018 // Try and solve the max-sum problem by performing node-pencil averaging over the graph 00019 // I.e. transform to equivalent problem by adding "potentials" to nodes and subtracting them 00020 // from arcs. This is done to equalise node costs and the cost of the maximal connecting arcs 00021 // If this converges the solution is to take the maximal nodes (which will then be arc-consistent). 00022 00023 class mmn_diffusion_solver 00024 { 00025 private: 00026 //: in below the map is indexed by the neighbour's node id 00027 00028 //: Inner vector indexed by source node state ID, map by neighbour node (t') 00029 typedef vcl_map<unsigned,vnl_vector<double > > potential_set_t; 00030 00031 //: Matrix referenced by [source node state ID][target node state ID] 00032 // Map ID is target node ID 00033 typedef vcl_map<unsigned, vnl_matrix<double > > neigh_arc_cost_t; 00034 00035 //:Store in graph form (so each node's neighbours are conveniently to hand) 00036 mmn_graph_rep1 graph_; 00037 00038 //: The arcs from which graph was generated 00039 vcl_vector<mmn_arc> arcs_; 00040 00041 //: Total number of nodes 00042 unsigned nnodes_; 00043 00044 //: Workspace for costs of each arc 00045 vcl_vector<neigh_arc_cost_t > arc_costs_; 00046 00047 //: Workspace for transformed costs of each arc 00048 vcl_vector<neigh_arc_cost_t > arc_costs_phi_; 00049 00050 00051 //: All the potentials at previous iteration (vector index is source node) 00052 vcl_vector<potential_set_t > phi_; 00053 00054 //: Update potentials calculated during this iteration (vector index is source node) 00055 vcl_vector<potential_set_t > phi_upd_; 00056 00057 //: Node costs (outer vector is node ID, inner vnl_vector is by state value) 00058 vcl_vector<vnl_vector<double> > node_costs_; 00059 00060 //: Node costs after phi transform (outer vector is node ID, inner vnl_vector is by state value) 00061 vcl_vector<vnl_vector<double> > node_costs_phi_; 00062 00063 //: Workspace for adjustment to potential 00064 vcl_vector<vcl_map<unsigned,vnl_vector<double > > > u_; 00065 00066 //: Current iteration count 00067 unsigned count_; 00068 00069 //: Max change in any potential value over this iteration 00070 double max_delta_; 00071 00072 //: max number of iterations allowed 00073 unsigned max_iterations_; 00074 00075 //: min iterations allowed before additional convergence checks 00076 unsigned min_iterations_; 00077 00078 //: Convergence criterion on max_delta_ 00079 double epsilon_; 00080 00081 //:verbose debug output 00082 bool verbose_; 00083 00084 //: Previous solution value at arc consistency check 00085 double soln_val_prev_; 00086 00087 //: Count of number of times solution value was unchanged 00088 unsigned nConverging_; 00089 00090 //: Max count of nConverging_ 00091 static unsigned gNCONVERGED; 00092 00093 //: Period at which arc consistency of solution is checked for 00094 static unsigned gACS_CHECK_PERIOD; 00095 00096 //: Check if we carry on 00097 bool continue_diffusion(); 00098 00099 //: Update all messages from input node to its neighbours 00100 void update_potentials_to_neighbours(unsigned inode, 00101 const vnl_vector<double>& node_cost); 00102 00103 //: Update all node and arc costs (equivalent transform) given phi (potentials) 00104 void transform_costs(); 00105 00106 //: Update node and arc costs (equivalent transform) given phi (potentials) for given node 00107 void transform_costs(unsigned inode); 00108 00109 //: Find maximal nodes and arcs and check if arc consistent 00110 bool arc_consistent_solution(vcl_vector<unsigned >& x); 00111 00112 //: Reset iteration counters 00113 void init(); 00114 //: Calculate final sum of node and arc values 00115 double solution_cost(vcl_vector<unsigned>& x); 00116 00117 public: 00118 //: Default constructor 00119 mmn_diffusion_solver(); 00120 00121 //: Construct with arcs 00122 mmn_diffusion_solver(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs); 00123 00124 //: Input the arcs that define the graph 00125 void set_arcs(unsigned num_nodes,const vcl_vector<mmn_arc>& arcs); 00126 00127 //: Find values for each node with minimise the total cost 00128 // \param node_cost: node_cost[i][j] is cost of selecting value j for node i 00129 // \param pair_cost: pair_cost[a](i,j) is cost of selecting values (i,j) for nodes at end of arc a. 00130 // \param x: On exit, x[i] gives choice for node i 00131 // NOTE: If arc a connects nodes v1,v2, the associated pair_cost is ordered 00132 // with the node with the lowest index being the first parameter. Thus if 00133 // v1 has value i1, v2 has value i2, then the cost of this choice is 00134 // (v1<v2?pair_cost(i1,i2):pair_cost(i2,i1)) 00135 // 00136 // Returns the minimum cost. Note that internally we deal with a maximisation 00137 // after negating the input costs, which are assumed to represent -log probabilities 00138 // In the return the boolean returns whether the algorithm was successful in converging 00139 // to an arc-consistent solution, and the double is the cost (negative minimum, i.e. -internal max) 00140 // Even if the solution is not arc-consistent a solution is still returned given by the local node 00141 // first maxima, but this may not then be optimal. 00142 vcl_pair<bool,double> operator()(const vcl_vector<vnl_vector<double> >& node_cost, 00143 const vcl_vector<vnl_matrix<double> >& arc_cost, 00144 vcl_vector<unsigned>& x); 00145 00146 //: final iteration count 00147 unsigned count() const {return count_;} 00148 00149 //: Produce shed loads of debug output 00150 void set_verbose(bool verbose) {verbose_=verbose;} 00151 }; 00152 00153 #endif // mmn_diffusion_solver_h_