contrib/brl/bseg/sdet/sdet_mrf_bp.h
Go to the documentation of this file.
00001 #ifndef sdet_mrf_bp_h_
00002 #define sdet_mrf_bp_h_
00003 //:
00004 // \file
00005 // \brief  A class for representing a set of sites in an MRF
00006 // \author J.L. Mundy
00007 // \date   26 March 2011
00008 //
00009 //  Each MRF site has a 4-neighborhood,
00010 //     u
00011 //   l x  r
00012 //     d
00013 // with index (u, l, r, d) ==> (0, 1, 2, 3)
00014 //
00015 #include <sdet/sdet_mrf_site_bp_sptr.h>
00016 #include <vil/vil_image_resource.h>
00017 #include <vil/vil_image_view.h>
00018 #include <vbl/vbl_array_2d.h>
00019 #include <vbl/vbl_ref_count.h>
00020 #include <vcl_vector.h>
00021 
00022 class sdet_mrf_bp : public vbl_ref_count
00023 {
00024  public:
00025   //: simple constructor for testing
00026   sdet_mrf_bp(unsigned ni, unsigned nj, unsigned n_labels);
00027 
00028   //: constructor with observed labels provided by resc
00029   // In this case the data cost is just $\lambda \times (fp-x)^2$
00030   sdet_mrf_bp(vil_image_resource_sptr obs_labels, unsigned n_labels,
00031               float discontinuity_cost, float truncation_cost,
00032               float kappa, float lambda);
00033 
00034   //: constructor with observed labels view
00035   sdet_mrf_bp(vil_image_view<float> const& obs_labels, unsigned n_labels,
00036               float discontinuity_cost, float truncation_cost,
00037               float kappa, float lambda);
00038 
00039   //: constructor with observed labels and variance resources
00040   sdet_mrf_bp(vil_image_resource_sptr obs_labels,
00041               vil_image_resource_sptr var,  unsigned n_labels,
00042               float discontinuity_cost, float truncation_cost,
00043               float kappa, float lambda);
00044 
00045   //: constructor with observed label and variance views.
00046   sdet_mrf_bp(vil_image_view<float> const& obs_labels,
00047               vil_image_view<float> const& var,  unsigned n_labels,
00048               float discontinuity_cost, float truncation_cost,
00049               float kappa, float lambda);
00050 
00051   //: limit cost at a discontinuity
00052   void set_discontinuity_cost(float discontinuity_cost)
00053     { discontinuity_cost_ = discontinuity_cost; }
00054 
00055   //: truncation of data cost
00056   void set_truncation_cost(float truncation_cost)
00057     { truncation_cost_ =truncation_cost; }
00058 
00059   //: contribution of data to cost
00060   void set_lambda(float lambda) { lambda_ = lambda; }
00061 
00062   //: the contribution of neighbor label difference to cost
00063   void set_kappa(float kappa) { kappa_ = kappa; }
00064 
00065   //: transform from image coordinates to node indices
00066   unsigned image_to_index(unsigned i, unsigned j) { return i + ni_*j; }
00067 
00068   //: transform from node indices to image coordinates
00069   void index_to_image(unsigned p, unsigned& i, unsigned& j)
00070     { j = p/ni_; i = p-j*ni_; }
00071 
00072   //: mrf dimension (columns)
00073   unsigned ni() const { return ni_; }
00074   //: mrf dimension (rows)
00075   unsigned nj() const { return nj_; }
00076 
00077   //: retrieve a site by image index
00078   sdet_mrf_site_bp_sptr site(unsigned i, unsigned j) { return sites_[j][i]; }
00079 
00080   //: retrieve a site by linear index
00081   sdet_mrf_site_bp_sptr site(unsigned p) { unsigned i, j; index_to_image(p,i,j); return sites_[j][i]; }
00082 
00083   //: get the contents of a prior message buffer
00084   vcl_vector<float> prior_message(unsigned i, unsigned j, unsigned n);
00085 
00086   //: set the contents of a prior message buffer
00087   void set_prior_message(unsigned i, unsigned j, unsigned n,
00088                          vcl_vector<float> const& msg);
00089 
00090   //:
00091   // all sites send messages to current buffer of neighbors,
00092   // using an O(n_labels) algorithm based on the lower envelope
00093   void send_messages_optimized();
00094 
00095   //: clear messages from all sites
00096   void clear();
00097 
00098   //: all sites print the contents of the prior buffers
00099   void print_prior_messages();
00100 
00101   //: all sites print their belief vector
00102   void print_belief_vectors();
00103 
00104   //: output
00105   vil_image_resource_sptr belief_image();
00106 
00107 
00108  protected:
00109   //members
00110   unsigned ni_;
00111   unsigned nj_;
00112   unsigned n_labels_;
00113   float discontinuity_cost_;
00114   float truncation_cost_;
00115   float kappa_;
00116   float lambda_;
00117   float min_;
00118   float max_;
00119   //the array of sites
00120   vbl_array_2d<sdet_mrf_site_bp_sptr> sites_;
00121   sdet_mrf_bp();
00122 };
00123 
00124 // === public functions ===
00125 
00126 //: computes the lower envelope of a message array with $V(fp, fq) = w|fp-fq|$
00127 //  Computes in place
00128 void lower_envelope_linear(float w, vcl_vector<float>& msg);
00129 
00130 //: computes the lower envelope of a message array with $V(fp, fq) = w(fp-fq)^2$
00131 //  Used in current implementation
00132 vcl_vector<float> lower_envelope_quadratic(float w,
00133                                            vcl_vector<float> const& h);
00134 
00135 #include <sdet/sdet_mrf_bp_sptr.h>
00136 #endif // sdet_mrf_bp_h_