contrib/mul/mbl/mbl_dyn_prog.cxx
Go to the documentation of this file.
00001 #include "mbl_dyn_prog.h"
00002 //:
00003 // \file
00004 
00005 #include <vcl_cstdlib.h>
00006 #include <vcl_iostream.h>
00007 #include <vcl_algorithm.h> // for std::min() & std::max()
00008 #include <vsl/vsl_indent.h>
00009 #include <vsl/vsl_binary_io.h>
00010 #include <vnl/io/vnl_io_matrix.h>
00011 
00012 //=======================================================================
00013 
00014 //=======================================================================
00015 // Dflt ctor
00016 //=======================================================================
00017 
00018 mbl_dyn_prog::mbl_dyn_prog()
00019 {
00020 }
00021 
00022 //=======================================================================
00023 // Destructor
00024 //=======================================================================
00025 
00026 mbl_dyn_prog::~mbl_dyn_prog()
00027 {
00028 }
00029 
00030 
00031 //: Construct path from links_, assuming it ends at end_state
00032 void mbl_dyn_prog::construct_path(vcl_vector<int>& x, int end_state)
00033 {
00034   unsigned int n = links_.rows();
00035   int ** b_data = links_.get_rows()-1;  // So that b_data[i] corresponds to i-th row
00036   if (x.size()!=n+1) x.resize(n+1);
00037   int *x_data = &x[0];
00038   x_data[n] = end_state;
00039   for (unsigned int i=n; i>0; --i)
00040     x_data[i-1] = b_data[i][x_data[i]];
00041 }
00042 
// Integer absolute value; used to turn a state displacement into an
// index into pair_cost.
static inline int mbl_abs(int i)
{
  if (i < 0) return -i;
  return i;
}
00044 
00045 //=======================================================================
00046 //: Compute running costs for DP problem with costs W
00047 //  Pair cost term:  C_i(x1,x2) = c(|x1-x2|)
00048 //  Size of c indicates maximum displacement between neighbouring
00049 //  states.
00050 //  If first_state>=0 then the first is constrained to that index value
00051 void mbl_dyn_prog::running_costs(
00052                const vnl_matrix<double>& W,
00053                const vnl_vector<double>& pair_cost,
00054                int first_state)
00055 {
00056   int n = W.rows();
00057   int n_states = W.columns();
00058   const double * const* W_data = W.data_array();
00059   int max_d = pair_cost.size()-1;
00060 
00061    // On completion b(i,j) shows the best prior state (ie at i)
00062    // leading to state j at time i+1
00063   links_.resize(n-1,n_states);
00064   int ** b_data = links_.get_rows()-1;
00065      // So that b_data[i] corresponds to i-th row
00066 
00067     // ci(j) is total cost to get to current state j
00068   running_cost_ = W.get_row(0);
00069   double *ci = running_cost_.data_block();
00070   next_cost_.set_size(n_states);
00071   double *ci_new = next_cost_.data_block();
00072 
00073   for (int i=1;i<n;++i)
00074   {
00075     int *bi = b_data[i];
00076     const double *wi = W_data[i];
00077 
00078     for (int j=0;j<n_states;++j)
00079     {
00080       // Evaluate best route to get to state j at time i
00081       int k_best = 0;
00082       double cost;
00083       double wj = wi[j];
00084       double cost_best;
00085 
00086       if (i==1 && first_state>=0)
00087       {
00088         // Special case: First point pinned down to first_pt
00089         k_best = first_state;
00090         int d = mbl_abs(j-k_best);
00091         if (d>max_d) cost_best=9e9;
00092         else
00093                      cost_best = ci[k_best] + pair_cost[d]+ wj;
00094       }
00095       else
00096       {
00097         int klo = vcl_max(0,j-max_d);
00098         int khi = vcl_min(n_states-1,j+max_d);
00099         k_best=klo;
00100         cost_best = ci[klo] + pair_cost[mbl_abs(j-klo)] + wj;
00101         for (int k=klo+1;k<=khi;++k)
00102         {
00103           cost = ci[k] + pair_cost[mbl_abs(j-k)] + wj;
00104           if (cost<cost_best)
00105           {
00106             cost_best=cost;
00107             k_best = k;
00108           }
00109         }
00110       }
00111 
00112       ci_new[j] = cost_best;
00113       bi[j] = k_best;
00114     }
00115 
00116     running_cost_=next_cost_;
00117   }
00118 }
00119 
00120 //=======================================================================
00121 //: Solve the dynamic programming problem with costs W
00122 //  Pair cost term:  C_i(x1,x2) = c(|x1-x2|)
00123 //  Size of c indicates maximum displacement between neighbouring
00124 //  states.
00125 //  If first_state>=0 then the first is constrained to that index value
00126 // \retval x  Optimal path
00127 // \return Total cost of given path
00128 double mbl_dyn_prog::solve(vcl_vector<int>& x,
00129                            const vnl_matrix<double>& W,
00130                            const vnl_vector<double>& pair_cost,
00131                            int first_state)
00132 {
00133   running_costs(W,pair_cost,first_state);
00134 
00135   double *ci = running_cost_.data_block();
00136   int n_states = W.columns();
00137 
00138   // Find the best final cost
00139   int best_j = 0;
00140   double best_cost = ci[0];
00141   for (int j=1;j<n_states;++j)
00142   {
00143     if (ci[j]<best_cost) { best_j=j; best_cost=ci[j]; }
00144   }
00145 
00146   construct_path(x,best_j);
00147 
00148   return best_cost;
00149 }
00150 
00151 //: Solve the DP problem including constraint between first and last
00152 //  Cost of moving from state i to state j is move_cost[j-i]
00153 //  (move_cost[i] must be valid for i in range [1-n_states,n_states-1])
00154 // Includes cost between x[0] and x[n-1] to ensure loop closure.
00155 // \retval x  Optimal path
00156 // \return Total cost of given path
00157 double mbl_dyn_prog::solve_loop(vcl_vector<int>& x,
00158                                 const vnl_matrix<double>& W,
00159                                 const vnl_vector<double>& pair_cost)
00160 {
00161   int n_states = W.columns();
00162   int max_d = pair_cost.size()-1;
00163 
00164   double best_overall_cost=9.9e9;
00165 
00166   vcl_vector<int> x1;
00167   for (int i0=0;i0<n_states;++i0)
00168   {
00169     // Solve with constraint that first is i0
00170     running_costs(W,pair_cost,i0);
00171 
00172     double *ci = running_cost_.data_block();
00173     // Find the best final cost
00174     int klo = vcl_max(0,i0-max_d);
00175     int khi = vcl_min(n_states-1,i0+max_d);
00176     int k_best=klo;
00177     double best_cost = ci[klo] + pair_cost[mbl_abs(i0-klo)];
00178     for (int k=klo+1;k<=khi;++k)
00179     {
00180       double cost = ci[k] + pair_cost[mbl_abs(i0-k)];
00181       if (cost<best_cost) { best_cost=cost; k_best = k; }
00182     }
00183 
00184     if (best_cost<best_overall_cost)
00185     {
00186       best_overall_cost=best_cost;
00187       construct_path(x,k_best);
00188     }
00189   }
00190 
00191   return best_overall_cost;
00192 }
00193 
00194 //=======================================================================
00195 // Method: version_no
00196 //=======================================================================
00197 
00198 short mbl_dyn_prog::version_no() const
00199 {
00200   return 1;
00201 }
00202 
00203 //=======================================================================
00204 // Method: is_a
00205 //=======================================================================
00206 
00207 vcl_string mbl_dyn_prog::is_a() const
00208 {
00209   return vcl_string("mbl_dyn_prog");
00210 }
00211 
00212 //=======================================================================
// Method: print_summary
00214 //=======================================================================
00215 
00216   // required if data is present in this class
00217 void mbl_dyn_prog::print_summary(vcl_ostream& os) const
00218 {
00219 }
00220 
00221 //=======================================================================
00222 // Method: save
00223 //=======================================================================
00224 
00225   // required if data is present in this class
00226 void mbl_dyn_prog::b_write(vsl_b_ostream& bfs) const
00227 {
00228   vsl_b_write(bfs,is_a());
00229   vsl_b_write(bfs,version_no());
00230 }
00231 
00232 //=======================================================================
00233 // Method: load
00234 //=======================================================================
00235 
00236   // required if data is present in this class
00237 void mbl_dyn_prog::b_read(vsl_b_istream& bfs)
00238 {
00239   if (!bfs) return;
00240 
00241   vcl_string name;
00242   vsl_b_read(bfs,name);
00243   if (name != is_a())
00244   {
00245     vcl_cerr << "DerivedClass::load :"
00246              << " Attempted to load object of type "
00247              << name <<" into object of type " << is_a() << vcl_endl;
00248     vcl_abort();
00249   }
00250 
00251   short version;
00252   vsl_b_read(bfs,version);
00253   switch (version)
00254   {
00255     case (1):
00256       // vsl_b_read(bfs,data_); // example of data input
00257       break;
00258     default:
00259       vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, mbl_dyn_prog &)\n"
00260                << "           Unknown version number "<< version << vcl_endl;
00261       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00262       return;
00263   }
00264 }
00265 
00266 //=======================================================================
// Associated function: vsl_b_write (binary output)
00268 //=======================================================================
00269 
00270 void vsl_b_write(vsl_b_ostream& bfs, const mbl_dyn_prog& b)
00271 {
00272     b.b_write(bfs);
00273 }
00274 
00275 //=======================================================================
// Associated function: vsl_b_read (binary input)
00277 //=======================================================================
00278 
00279 void vsl_b_read(vsl_b_istream& bfs, mbl_dyn_prog& b)
00280 {
00281     b.b_read(bfs);
00282 }
00283 
00284 //=======================================================================
00285 // Associated function: operator<<
00286 //=======================================================================
00287 
00288 vcl_ostream& operator<<(vcl_ostream& os,const mbl_dyn_prog& b)
00289 {
00290   os << b.is_a() << ": ";
00291   vsl_indent_inc(os);
00292   b.print_summary(os);
00293   vsl_indent_dec(os);
00294   return os;
00295 }