contrib/brl/bseg/sdet/sdet_mrf_bp.cxx
Go to the documentation of this file.
00001 #include "sdet_mrf_bp.h"
00002 #include <sdet/sdet_mrf_site_bp.h>
00003 #include <vil/vil_convert.h>
00004 #include <vil/vil_math.h>
00005 #include <vil/vil_new.h>
00006 #include <vnl/vnl_numeric_traits.h>
00007 #include <vcl_cassert.h>
//
// Relative image offsets of the four neighbors of a site, indexed by the
// neighbor number n:
//
//          0 (u)
//   1 (l)    x    2 (r)
//          3 (d)
//
// di[n] is the offset in the column index i and dj[n] the offset in the row
// index j.  The ordering is chosen so that the neighbor opposite to n is
// always 3-n (up <-> down, left <-> right).
static const int di[4] = {  0, -1,  1,  0 };
static const int dj[4] = { -1,  0,  0,  1 };
00016 
// In-place lower envelope for a linear (L1) label cost:
//   msg[q] <- min over p of ( msg[p] + w*|q-p| )
// computed with one forward and one backward sweep, each step propagating
// the cheaper of the current value and the neighbor's value plus w.
void lower_envelope_linear(float w, vcl_vector<float>& msg)
{
  typedef vcl_vector<float>::size_type size_type;
  const size_type count = msg.size();

  // forward sweep: let each label inherit from its left neighbor
  for (size_type a = 1; a < count; ++a) {
    const float from_left = msg[a-1] + w;
    if (!(msg[a] < from_left))
      msg[a] = from_left;
  }

  // backward sweep: let each label inherit from its right neighbor
  for (size_type b = count; b > 1; --b) {
    const size_type a = b - 2;
    const float from_right = msg[a+1] + w;
    if (!(msg[a] < from_right))
      msg[a] = from_right;
  }
}
00039 
// Abscissa at which the parabolas  w*(x-q)^2 + h[q]  and  w*(x-p)^2 + h[p]
// intersect.  Requires q != p and w != 0.  Used by lower_envelope_quadratic;
// having a single definition keeps the weight w in every evaluation.
static float envelope_parabola_intersection(float w,
                                            vcl_vector<float> const& h,
                                            int q, int p)
{
  return ((h[q]+w*q*q)-(h[p]+w*p*p)) / (2*w*(q-p));
}

// Quadratic distance transform (lower envelope of parabolas, after
// Felzenszwalb & Huttenlocher):
//   env_out[fq] = min over fp of ( w*(fq-fp)^2 + h[fp] )
// Each label is pushed onto the envelope once and removed at most once, and
// the read-out pass only advances k, so the cost is linear in h.size().
//
// w is expected to be non-negative (the parabolas must open upward); w == 0
// yields min(h) for every label.  An empty h yields an empty result.
vcl_vector<float> lower_envelope_quadratic(float w,
                                           vcl_vector<float> const& h)
{
  const int nlabels = static_cast<int>(h.size());
  vcl_vector<float> env_out(nlabels);
  if (nlabels == 0)
    return env_out;

  // Degenerate weight: every label costs min(h).  The general algorithm
  // below would divide by zero.
  if (w == 0.0f) {
    float hmin = h[0];
    for (int fq = 1; fq < nlabels; ++fq)
      if (h[fq] < hmin)
        hmin = h[fq];
    for (int fq = 0; fq < nlabels; ++fq)
      env_out[fq] = hmin;
    return env_out;
  }

  // v[0..k]: labels whose parabolas currently form the lower envelope.
  // z[k]   : for k >= 1, the abscissa where parabola v[k] takes over from
  //          v[k-1].  Parabola v[0] extends to the left without bound.
  vcl_vector<int> v(nlabels);
  vcl_vector<float> z(nlabels);
  int k = 0;
  v[0] = 0;
  for (int fq = 1; fq < nlabels; ++fq) {
    float s = envelope_parabola_intersection(w, h, fq, v[k]);
    // Discard parabolas hidden by the new one.  The leftmost parabola is
    // never discarded, which also keeps k from going negative.
    while (k > 0 && s <= z[k]) {
      --k;
      s = envelope_parabola_intersection(w, h, fq, v[k]);
    }
    ++k;
    v[k] = fq;
    z[k] = s;
  }

  // Read the envelope: advance to the parabola whose interval contains fq.
  // The last parabola extends to the right without bound.
  const int last = k;
  k = 0;
  for (int fq = 0; fq < nlabels; ++fq) {
    while (k < last && z[k+1] < fq)
      ++k;
    const float d = static_cast<float>(fq - v[k]);
    env_out[fq] = w*d*d + h[v[k]];
  }
  return env_out;
}
00071 
00072 sdet_mrf_bp::sdet_mrf_bp(unsigned ni, unsigned nj,
00073                          unsigned n_labels)
00074   : ni_(ni), nj_(nj), n_labels_(n_labels), discontinuity_cost_(1.0f),
00075     truncation_cost_(1.0f), kappa_(1.0f), lambda_(1.0f), min_(0.0f), max_(0.0f)
00076 {
00077   sites_.resize(nj_, ni_);
00078   for (unsigned j = 0; j<nj_; ++j)
00079     for (unsigned i = 0; i<ni_; ++i)
00080       sites_[j][i]=new sdet_mrf_site_bp(n_labels_,lambda_, truncation_cost_);
00081 }
00082 
00083 sdet_mrf_bp::sdet_mrf_bp(vil_image_resource_sptr obs_labels, unsigned n_labels,
00084                          float discontinuity_cost, float truncation_cost,
00085                          float kappa, float lambda)
00086   : ni_(0), nj_(0), n_labels_(n_labels),
00087     discontinuity_cost_(discontinuity_cost),
00088     truncation_cost_(truncation_cost), kappa_(kappa),
00089     lambda_(lambda), min_(0.0f), max_(0.0f)
00090 {
00091   if (!obs_labels) return;
00092   ni_=obs_labels->ni();   nj_=obs_labels->nj();
00093   vil_image_view_base_sptr temp = obs_labels->get_view();
00094   vil_image_view<float> view = *vil_convert_cast(float(), temp);
00095   vil_math_value_range(view, min_, max_);
00096   if (min_ >= max_) return;
00097   sites_.resize(nj_, ni_);
00098   float scale = (n_labels-1)/(max_-min_);
00099   for (unsigned j = 0; j<nj_; ++j)
00100     for (unsigned i = 0; i<ni_; ++i) {
00101       sdet_mrf_site_bp_sptr s =
00102         new sdet_mrf_site_bp(n_labels_, lambda_, truncation_cost_);
00103       s->set_label(scale*(view(i,j)-min_));
00104       sites_[j][i]=s;
00105     }
00106 }
00107 
00108 sdet_mrf_bp::sdet_mrf_bp(vil_image_view<float> const& obs_labels,
00109                          unsigned n_labels, float discontinuity_cost,
00110                          float truncation_cost, float kappa, float lambda)
00111   : n_labels_(n_labels), discontinuity_cost_(discontinuity_cost),
00112     truncation_cost_(truncation_cost), kappa_(kappa), lambda_(lambda),
00113     min_(0.0f), max_(0.0f)
00114 {
00115   if (!obs_labels) return;
00116   ni_=obs_labels.ni();   nj_=obs_labels.nj();
00117   vil_math_value_range(obs_labels, min_, max_);
00118   if (min_ >= max_) return;
00119   sites_.resize(nj_, ni_);
00120   float scale = (n_labels-1)/(max_-min_);
00121   for (unsigned j = 0; j<nj_; ++j)
00122     for (unsigned i = 0; i<ni_; ++i) {
00123       sdet_mrf_site_bp_sptr s =
00124         new sdet_mrf_site_bp(n_labels_, lambda_, truncation_cost_);
00125       s->set_label(scale*(obs_labels(i,j)-min_));
00126       sites_[j][i]=s;
00127     }
00128 }
00129 
00130 sdet_mrf_bp::sdet_mrf_bp(vil_image_view<float> const& obs_labels,
00131                          vil_image_view<float> const& var,  unsigned n_labels,
00132                          float discontinuity_cost, float truncation_cost,
00133                          float kappa, float lambda)
00134   : n_labels_(n_labels), discontinuity_cost_(discontinuity_cost),
00135     truncation_cost_(truncation_cost), kappa_(kappa),
00136     lambda_(lambda), min_(0.0f), max_(0.0f)
00137 {
00138   ni_=obs_labels.ni();   nj_=obs_labels.nj();
00139   if (!ni_||!nj_) return;
00140   vil_math_value_range(obs_labels, min_, max_);
00141   if (min_ >= max_) return;
00142   sites_.resize(nj_, ni_);
00143   float scale = (n_labels-1)/(max_-min_);
00144   int ni = static_cast<int>(ni_), nj = static_cast<int>(nj_);
00145   for (int j = 0; j<nj; ++j)
00146     for (int i = 0; i<ni; ++i) {
00147       float vlamb = lambda_;
00148       if (var(i,j)>0.0f)
00149         vlamb = lambda_/var(i,j);
00150       sdet_mrf_site_bp_sptr s =
00151         new sdet_mrf_site_bp(n_labels_,vlamb, truncation_cost_);
00152       s->set_label(scale*(obs_labels(i,j)-min_));
00153       sites_[j][i]=s;
00154     }
00155 }
00156 
00157 sdet_mrf_bp::sdet_mrf_bp(vil_image_resource_sptr  obs_labels,
00158                          vil_image_resource_sptr  var,
00159                          unsigned n_labels, float discontinuity_cost,
00160                          float truncation_cost, float kappa, float lambda)
00161   :  n_labels_(n_labels),discontinuity_cost_(discontinuity_cost),
00162      truncation_cost_(truncation_cost), kappa_(kappa), lambda_(lambda),
00163      min_(0.0f), max_(0.0f)
00164 {
00165   if (!obs_labels) return;
00166   ni_=obs_labels->ni();   nj_=obs_labels->nj();
00167   vil_image_view_base_sptr temp = obs_labels->get_view();
00168   vil_image_view<float> view = *vil_convert_cast(float(), temp);
00169   vil_math_value_range(view, min_, max_);
00170   if (min_ >= max_) return;
00171   sites_.resize(nj_, ni_);
00172   float scale = (n_labels-1)/(max_-min_);
00173   vil_image_view_base_sptr tempv = var->get_view();
00174   vil_image_view<float> var_view = *vil_convert_cast(float(), temp);
00175   int ni = static_cast<int>(ni_), nj = static_cast<int>(nj_);
00176   for (int j = 0; j<nj; ++j)
00177     for (int i = 0; i<ni; ++i) {
00178       float vlamb = lambda_/var_view(i,j);
00179       sdet_mrf_site_bp_sptr s =
00180         new sdet_mrf_site_bp(n_labels_,vlamb, truncation_cost_);
00181       s->set_label(scale*(view(i,j)-min_));
00182       sites_[j][i]=s;
00183     }
00184 }
00185 
00186 void sdet_mrf_bp::send_messages_optimized()
00187 {
00188   const int ni = static_cast<int>(ni_), nj = static_cast<int>(nj_);
00189   for (int j = 0; j<nj; ++j)
00190     for (int i = 0; i<ni; ++i) {
00191       //site sending the messages
00192       sdet_mrf_site_bp_sptr sp = sites_[j][i];
00193       for (int n = 0; n<4; ++n) {
00194         int ki = i+di[n], kj = j+dj[n];
00195         if (ki<0||ki>=ni||kj<0||kj>=nj)
00196           continue;
00197         //site receiving a message
00198         sdet_mrf_site_bp_sptr sq = sites_[kj][ki];
00199 
00200         //initialize message with h(fp)
00201         vcl_vector<float> temp(n_labels_), msg;
00202         // minr is the smallest value of h
00203         float minh = vnl_numeric_traits<float>::maxval;
00204         for (unsigned fq = 0; fq<n_labels_; ++fq) {
00205           temp[fq]=sp->h(n, fq);
00206           if (temp[fq]<minh)
00207             minh = temp[fq];
00208         }
00209         // compute the lower bound on msg(q)
00210         msg = lower_envelope_quadratic(kappa_, temp);
00211 
00212         // clamp message value to an upper bound (msg min + disc cost)
00213         minh += discontinuity_cost_;
00214         for (unsigned fq=0; fq<n_labels_; ++fq)
00215           if (msg[fq]>minh)
00216             msg[fq]=minh;
00217 
00218         // normalize message values to prevent divergence
00219         // compute the average message value
00220         float summ = 0.0f;
00221         for (unsigned fq=0; fq<n_labels_; ++fq)
00222           summ += msg[fq];
00223         summ /= n_labels_;
00224 
00225         //subtract the average
00226         for (unsigned fq=0; fq<n_labels_; ++fq) {
00227           float ms = msg[fq]-summ;
00228           //if this assert fails, the number of labels is too large
00229           //compared to the dynamic range of the message elements
00230           assert(ms<=static_cast<float>(vnl_numeric_traits<short>::maxval)
00231               && ms>=-static_cast<float>(vnl_numeric_traits<short>::maxval)-1.f);
00232           sq->set_cur_message(3-n, fq, ms);
00233         }
00234       }
00235     }
00236   //all done sending messages so swap buffers
00237   for (int j = 0; j<nj; ++j)
00238     for (int i = 0; i<ni; ++i) {
00239       sdet_mrf_site_bp_sptr sp = sites_[j][i];
00240       sp->switch_buffers();
00241     }
00242 }
00243 
00244 void sdet_mrf_bp::set_prior_message(unsigned i, unsigned j, unsigned n,
00245                                     vcl_vector<float> const& msg)
00246 {
00247   this->site(i,j)->set_prior_message(n, msg);
00248 }
00249 
00250 void sdet_mrf_bp::print_prior_messages()
00251 {
00252   vcl_cout << "Neighbor layout\n"
00253            << "     0\n"
00254            << "  1  x  2\n"
00255            << "     3\n\n";
00256 
00257   for (unsigned j = 0; j<nj_; ++j)
00258     for (unsigned i = 0; i<ni_; ++i) {
00259       sdet_mrf_site_bp_sptr sp = sites_[j][i];
00260       vcl_cout << " site(" << i << ' ' << j << ")==>\n";
00261       sp->print_prior_messages();
00262     }
00263 }
00264 
00265 void sdet_mrf_bp::print_belief_vectors()
00266 {
00267   for (unsigned j = 0; j<nj_; ++j)
00268     for (unsigned i = 0; i<ni_; ++i) {
00269       sdet_mrf_site_bp_sptr sp = sites_[j][i];
00270       vcl_cout << " site(" << i << ' ' << j << ")==>\n";
00271       sp->print_belief_vector();
00272     }
00273 }
00274 
00275 vil_image_resource_sptr sdet_mrf_bp::belief_image()
00276 {
00277   vil_image_resource_sptr ret = 0;
00278   if (nj_==0||ni_==0)
00279     return ret;
00280   vil_image_view<float> view(ni_, nj_);
00281   if (min_==max_)
00282     return ret;
00283   float scale = (max_-min_)/static_cast<float>(n_labels_-1);
00284   for (unsigned j = 0; j<nj_; ++j)
00285     for (unsigned i = 0; i<ni_; ++i) {
00286       sdet_mrf_site_bp_sptr sp = sites_[j][i];
00287       if (!sp) continue;
00288       float label = static_cast<float>(sp->believed_label());
00289       view(i,j) = scale*label + min_;
00290     }
00291   ret = vil_new_image_resource_of_view(view);
00292   return ret;
00293 }
00294 
00295 vcl_vector<float> sdet_mrf_bp::prior_message(unsigned i, unsigned j, unsigned n)
00296 {
00297   sdet_mrf_site_bp_sptr sp = sites_[j][i];
00298   return sp->prior_message(n);
00299 }
00300 
00301 void sdet_mrf_bp::clear()
00302 {
00303   for (unsigned j = 0; j<nj_; ++j)
00304     for (unsigned i = 0; i<ni_; ++i) {
00305       sdet_mrf_site_bp_sptr sp = sites_[j][i];
00306       if (!sp) continue;
00307       sp->clear();
00308     }
00309 }