00001 #include "sdet_mrf_bp.h"
00002 #include <sdet/sdet_mrf_site_bp.h>
00003 #include <vil/vil_convert.h>
00004 #include <vil/vil_math.h>
00005 #include <vil/vil_new.h>
00006 #include <vnl/vnl_numeric_traits.h>
00007 #include <vcl_cassert.h>
00008
00009
00010
00011
00012
00013
// Offsets (di[n], dj[n]) from a site to its n-th 4-connected neighbour:
//   n = 0 : ( 0,-1)    n = 1 : (-1, 0)    n = 2 : ( 1, 0)    n = 3 : ( 0, 1)
// The table is arranged so that the opposite direction of n is 3-n; the
// message passing code uses that to address the receiving site's slot.
static const int di[4] = {  0, -1, +1,  0 };
static const int dj[4] = { -1,  0,  0, +1 };
00016
//: In-place two-pass lower envelope for a linear (absolute difference) cost.
// A forward sweep lets every entry be lowered by its left neighbour plus w,
// then a backward sweep does the same from the right. For w >= 0 this leaves
// msg[q] = min over p of ( msg[p] + w*|q-p| ).
// \param w    slope of the linear penalty per label step
// \param msg  costs per label; overwritten with the envelope
void lower_envelope_linear(float w, vcl_vector<float>& msg)
{
  const int count = static_cast<int>(msg.size());

  // Left-to-right: propagate cheaper values from lower labels.
  for (int q = 1; q < count; ++q) {
    const float from_left = msg[q - 1] + w;
    msg[q] = (msg[q] < from_left) ? msg[q] : from_left;
  }

  // Right-to-left: propagate cheaper values from higher labels.
  for (int q = count - 2; q >= 0; --q) {
    const float from_right = msg[q + 1] + w;
    msg[q] = (msg[q] < from_right) ? msg[q] : from_right;
  }
}
00039
//: Abscissa where the parabolas  w*(x-q)^2 + h[q]  and  w*(x-p)^2 + h[p]  meet.
// Requires w != 0 and q != p.
static float lower_envelope_intersection(float w, vcl_vector<float> const& h,
                                         int q, int p)
{
  return ((h[q] + w*q*q) - (h[p] + w*p*p)) / (2*w*(q - p));
}

//: Lower envelope of quadratic cones rooted at each label.
// Returns env with env[fq] = min over fp of ( h[fp] + w*(fq-fp)^2 ),
// for fq in [0, h.size()).
// For w > 0 the linear-time parabola-envelope algorithm is used (v holds the
// labels whose parabolas form the envelope, z the boundaries between them).
// For w <= 0 that algorithm does not apply (and w == 0 would divide by zero),
// so the minimum is evaluated directly.
// An empty h yields an empty result.
vcl_vector<float> lower_envelope_quadratic(float w,
                                           vcl_vector<float> const& h)
{
  const int nlabels = static_cast<int>(h.size());
  vcl_vector<float> env_out(nlabels);
  if (nlabels == 0)
    return env_out;

  if (!(w > 0.0f)) {
    // Direct O(n^2) evaluation of the same minimum.
    for (int fq = 0; fq < nlabels; ++fq) {
      float best = w*fq*fq + h[0];
      for (int fp = 1; fp < nlabels; ++fp) {
        const float cost = w*(fq-fp)*(fq-fp) + h[fp];
        if (cost < best)
          best = cost;
      }
      env_out[fq] = best;
    }
    return env_out;
  }

  vcl_vector<int> v(nlabels);       // labels of parabolas in the envelope
  vcl_vector<float> z(nlabels+1);   // z[k]..z[k+1] is where parabola v[k] is lowest
  int k = 0;
  v[0] = 0;
  z[0] = -vnl_numeric_traits<float>::maxval;
  z[1] =  vnl_numeric_traits<float>::maxval;

  for (int fq = 1; fq < nlabels; ++fq) {
    // The intersection must include the weight w on every evaluation; the
    // recomputation after popping a parabola previously omitted it.
    float s = lower_envelope_intersection(w, h, fq, v[k]);
    // k > 0 keeps k from underflowing should s reach the -maxval sentinel.
    while (k > 0 && s <= z[k]) {
      --k;
      s = lower_envelope_intersection(w, h, fq, v[k]);
    }
    ++k;
    v[k] = fq;
    z[k] = s;
    z[k+1] = vnl_numeric_traits<float>::maxval;
  }

  // Read the envelope back: advance to the parabola covering each label.
  k = 0;
  for (int fq = 0; fq < nlabels; ++fq) {
    while (z[k+1] < fq)
      ++k;
    env_out[fq] = w*(fq-v[k])*(fq-v[k]) + h[v[k]];
  }
  return env_out;
}
00071
00072 sdet_mrf_bp::sdet_mrf_bp(unsigned ni, unsigned nj,
00073 unsigned n_labels)
00074 : ni_(ni), nj_(nj), n_labels_(n_labels), discontinuity_cost_(1.0f),
00075 truncation_cost_(1.0f), kappa_(1.0f), lambda_(1.0f), min_(0.0f), max_(0.0f)
00076 {
00077 sites_.resize(nj_, ni_);
00078 for (unsigned j = 0; j<nj_; ++j)
00079 for (unsigned i = 0; i<ni_; ++i)
00080 sites_[j][i]=new sdet_mrf_site_bp(n_labels_,lambda_, truncation_cost_);
00081 }
00082
00083 sdet_mrf_bp::sdet_mrf_bp(vil_image_resource_sptr obs_labels, unsigned n_labels,
00084 float discontinuity_cost, float truncation_cost,
00085 float kappa, float lambda)
00086 : ni_(0), nj_(0), n_labels_(n_labels),
00087 discontinuity_cost_(discontinuity_cost),
00088 truncation_cost_(truncation_cost), kappa_(kappa),
00089 lambda_(lambda), min_(0.0f), max_(0.0f)
00090 {
00091 if (!obs_labels) return;
00092 ni_=obs_labels->ni(); nj_=obs_labels->nj();
00093 vil_image_view_base_sptr temp = obs_labels->get_view();
00094 vil_image_view<float> view = *vil_convert_cast(float(), temp);
00095 vil_math_value_range(view, min_, max_);
00096 if (min_ >= max_) return;
00097 sites_.resize(nj_, ni_);
00098 float scale = (n_labels-1)/(max_-min_);
00099 for (unsigned j = 0; j<nj_; ++j)
00100 for (unsigned i = 0; i<ni_; ++i) {
00101 sdet_mrf_site_bp_sptr s =
00102 new sdet_mrf_site_bp(n_labels_, lambda_, truncation_cost_);
00103 s->set_label(scale*(view(i,j)-min_));
00104 sites_[j][i]=s;
00105 }
00106 }
00107
00108 sdet_mrf_bp::sdet_mrf_bp(vil_image_view<float> const& obs_labels,
00109 unsigned n_labels, float discontinuity_cost,
00110 float truncation_cost, float kappa, float lambda)
00111 : n_labels_(n_labels), discontinuity_cost_(discontinuity_cost),
00112 truncation_cost_(truncation_cost), kappa_(kappa), lambda_(lambda),
00113 min_(0.0f), max_(0.0f)
00114 {
00115 if (!obs_labels) return;
00116 ni_=obs_labels.ni(); nj_=obs_labels.nj();
00117 vil_math_value_range(obs_labels, min_, max_);
00118 if (min_ >= max_) return;
00119 sites_.resize(nj_, ni_);
00120 float scale = (n_labels-1)/(max_-min_);
00121 for (unsigned j = 0; j<nj_; ++j)
00122 for (unsigned i = 0; i<ni_; ++i) {
00123 sdet_mrf_site_bp_sptr s =
00124 new sdet_mrf_site_bp(n_labels_, lambda_, truncation_cost_);
00125 s->set_label(scale*(obs_labels(i,j)-min_));
00126 sites_[j][i]=s;
00127 }
00128 }
00129
00130 sdet_mrf_bp::sdet_mrf_bp(vil_image_view<float> const& obs_labels,
00131 vil_image_view<float> const& var, unsigned n_labels,
00132 float discontinuity_cost, float truncation_cost,
00133 float kappa, float lambda)
00134 : n_labels_(n_labels), discontinuity_cost_(discontinuity_cost),
00135 truncation_cost_(truncation_cost), kappa_(kappa),
00136 lambda_(lambda), min_(0.0f), max_(0.0f)
00137 {
00138 ni_=obs_labels.ni(); nj_=obs_labels.nj();
00139 if (!ni_||!nj_) return;
00140 vil_math_value_range(obs_labels, min_, max_);
00141 if (min_ >= max_) return;
00142 sites_.resize(nj_, ni_);
00143 float scale = (n_labels-1)/(max_-min_);
00144 int ni = static_cast<int>(ni_), nj = static_cast<int>(nj_);
00145 for (int j = 0; j<nj; ++j)
00146 for (int i = 0; i<ni; ++i) {
00147 float vlamb = lambda_;
00148 if (var(i,j)>0.0f)
00149 vlamb = lambda_/var(i,j);
00150 sdet_mrf_site_bp_sptr s =
00151 new sdet_mrf_site_bp(n_labels_,vlamb, truncation_cost_);
00152 s->set_label(scale*(obs_labels(i,j)-min_));
00153 sites_[j][i]=s;
00154 }
00155 }
00156
00157 sdet_mrf_bp::sdet_mrf_bp(vil_image_resource_sptr obs_labels,
00158 vil_image_resource_sptr var,
00159 unsigned n_labels, float discontinuity_cost,
00160 float truncation_cost, float kappa, float lambda)
00161 : n_labels_(n_labels),discontinuity_cost_(discontinuity_cost),
00162 truncation_cost_(truncation_cost), kappa_(kappa), lambda_(lambda),
00163 min_(0.0f), max_(0.0f)
00164 {
00165 if (!obs_labels) return;
00166 ni_=obs_labels->ni(); nj_=obs_labels->nj();
00167 vil_image_view_base_sptr temp = obs_labels->get_view();
00168 vil_image_view<float> view = *vil_convert_cast(float(), temp);
00169 vil_math_value_range(view, min_, max_);
00170 if (min_ >= max_) return;
00171 sites_.resize(nj_, ni_);
00172 float scale = (n_labels-1)/(max_-min_);
00173 vil_image_view_base_sptr tempv = var->get_view();
00174 vil_image_view<float> var_view = *vil_convert_cast(float(), temp);
00175 int ni = static_cast<int>(ni_), nj = static_cast<int>(nj_);
00176 for (int j = 0; j<nj; ++j)
00177 for (int i = 0; i<ni; ++i) {
00178 float vlamb = lambda_/var_view(i,j);
00179 sdet_mrf_site_bp_sptr s =
00180 new sdet_mrf_site_bp(n_labels_,vlamb, truncation_cost_);
00181 s->set_label(scale*(view(i,j)-min_));
00182 sites_[j][i]=s;
00183 }
00184 }
00185
00186 void sdet_mrf_bp::send_messages_optimized()
00187 {
00188 const int ni = static_cast<int>(ni_), nj = static_cast<int>(nj_);
00189 for (int j = 0; j<nj; ++j)
00190 for (int i = 0; i<ni; ++i) {
00191
00192 sdet_mrf_site_bp_sptr sp = sites_[j][i];
00193 for (int n = 0; n<4; ++n) {
00194 int ki = i+di[n], kj = j+dj[n];
00195 if (ki<0||ki>=ni||kj<0||kj>=nj)
00196 continue;
00197
00198 sdet_mrf_site_bp_sptr sq = sites_[kj][ki];
00199
00200
00201 vcl_vector<float> temp(n_labels_), msg;
00202
00203 float minh = vnl_numeric_traits<float>::maxval;
00204 for (unsigned fq = 0; fq<n_labels_; ++fq) {
00205 temp[fq]=sp->h(n, fq);
00206 if (temp[fq]<minh)
00207 minh = temp[fq];
00208 }
00209
00210 msg = lower_envelope_quadratic(kappa_, temp);
00211
00212
00213 minh += discontinuity_cost_;
00214 for (unsigned fq=0; fq<n_labels_; ++fq)
00215 if (msg[fq]>minh)
00216 msg[fq]=minh;
00217
00218
00219
00220 float summ = 0.0f;
00221 for (unsigned fq=0; fq<n_labels_; ++fq)
00222 summ += msg[fq];
00223 summ /= n_labels_;
00224
00225
00226 for (unsigned fq=0; fq<n_labels_; ++fq) {
00227 float ms = msg[fq]-summ;
00228
00229
00230 assert(ms<=static_cast<float>(vnl_numeric_traits<short>::maxval)
00231 && ms>=-static_cast<float>(vnl_numeric_traits<short>::maxval)-1.f);
00232 sq->set_cur_message(3-n, fq, ms);
00233 }
00234 }
00235 }
00236
00237 for (int j = 0; j<nj; ++j)
00238 for (int i = 0; i<ni; ++i) {
00239 sdet_mrf_site_bp_sptr sp = sites_[j][i];
00240 sp->switch_buffers();
00241 }
00242 }
00243
00244 void sdet_mrf_bp::set_prior_message(unsigned i, unsigned j, unsigned n,
00245 vcl_vector<float> const& msg)
00246 {
00247 this->site(i,j)->set_prior_message(n, msg);
00248 }
00249
00250 void sdet_mrf_bp::print_prior_messages()
00251 {
00252 vcl_cout << "Neighbor layout\n"
00253 << " 0\n"
00254 << " 1 x 2\n"
00255 << " 3\n\n";
00256
00257 for (unsigned j = 0; j<nj_; ++j)
00258 for (unsigned i = 0; i<ni_; ++i) {
00259 sdet_mrf_site_bp_sptr sp = sites_[j][i];
00260 vcl_cout << " site(" << i << ' ' << j << ")==>\n";
00261 sp->print_prior_messages();
00262 }
00263 }
00264
00265 void sdet_mrf_bp::print_belief_vectors()
00266 {
00267 for (unsigned j = 0; j<nj_; ++j)
00268 for (unsigned i = 0; i<ni_; ++i) {
00269 sdet_mrf_site_bp_sptr sp = sites_[j][i];
00270 vcl_cout << " site(" << i << ' ' << j << ")==>\n";
00271 sp->print_belief_vector();
00272 }
00273 }
00274
00275 vil_image_resource_sptr sdet_mrf_bp::belief_image()
00276 {
00277 vil_image_resource_sptr ret = 0;
00278 if (nj_==0||ni_==0)
00279 return ret;
00280 vil_image_view<float> view(ni_, nj_);
00281 if (min_==max_)
00282 return ret;
00283 float scale = (max_-min_)/static_cast<float>(n_labels_-1);
00284 for (unsigned j = 0; j<nj_; ++j)
00285 for (unsigned i = 0; i<ni_; ++i) {
00286 sdet_mrf_site_bp_sptr sp = sites_[j][i];
00287 if (!sp) continue;
00288 float label = static_cast<float>(sp->believed_label());
00289 view(i,j) = scale*label + min_;
00290 }
00291 ret = vil_new_image_resource_of_view(view);
00292 return ret;
00293 }
00294
00295 vcl_vector<float> sdet_mrf_bp::prior_message(unsigned i, unsigned j, unsigned n)
00296 {
00297 sdet_mrf_site_bp_sptr sp = sites_[j][i];
00298 return sp->prior_message(n);
00299 }
00300
00301 void sdet_mrf_bp::clear()
00302 {
00303 for (unsigned j = 0; j<nj_; ++j)
00304 for (unsigned i = 0; i<ni_; ++i) {
00305 sdet_mrf_site_bp_sptr sp = sites_[j][i];
00306 if (!sp) continue;
00307 sp->clear();
00308 }
00309 }