Go to the documentation of this file.00001
00002 #ifndef bsta_mixture_h_
00003 #define bsta_mixture_h_
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016 #include "bsta_distribution.h"
00017 #include <vcl_cassert.h>
00018 #include <vcl_vector.h>
00019 #include <vcl_algorithm.h>
00020 #include <vcl_iostream.h>
00021 #include "bsta_sampler.h"
00022 #include <vpdl/vpdt/vpdt_dist_traits.h>
00023 #include <vnl/vnl_random.h>
00024
00025
00026 template <class dist_>
00027 class bsta_mixture : public bsta_distribution<typename dist_::math_type,
00028 dist_::dimension>
00029 {
00030 public:
00031 typedef dist_ dist_type;
00032 typedef dist_type component_type;
00033
00034 enum { max_components = 0 };
00035
00036 private:
00037 typedef typename dist_::math_type T;
00038 typedef typename dist_::vector_type vector_;
00039
00040
00041
00042
00043
00044
00045 struct component
00046 {
00047
00048 component(): distribution(), weight(T(0)) {}
00049
00050 component(const dist_& d, const T& w = T(0) )
00051 : distribution(d), weight(w) {}
00052
00053
00054 bool operator< (const component& rhs) const
00055 { return this->weight > rhs.weight; }
00056
00057
00058
00059
00060 dist_ distribution;
00061
00062 T weight;
00063 };
00064
00065
00066
00067 class sort_weight
00068 {
00069 public:
00070 bool operator() (const component* c1, const component* c2) const
00071 { return c1->weight > c2->weight; }
00072 };
00073
00074
00075 template <class comp_type_>
00076 class sort_adaptor
00077 {
00078 public:
00079 sort_adaptor(comp_type_ c) : comp(c) {}
00080 bool operator() (const component* const c1, const component* const c2) const
00081 { return comp(c1->distribution, c1->weight, c2->distribution, c2->weight); }
00082 comp_type_ comp;
00083 };
00084
00085
00086 vcl_vector<component*> components_;
00087
00088 public:
00089
00090 bsta_mixture<dist_>() {}
00091
00092
00093 bsta_mixture<dist_>(const bsta_mixture<dist_>& other)
00094 : components_(other.components_.size(),NULL)
00095 {
00096
00097 for (unsigned int i=0; i<components_.size(); ++i){
00098 components_[i] = new component(*other.components_[i]);
00099 }
00100 }
00101
00102
00103 ~bsta_mixture<dist_>()
00104 {
00105 for (unsigned int i=0; i<components_.size(); ++i){
00106 delete components_[i];
00107 }
00108 }
00109
00110
00111 bsta_mixture<dist_>& operator= (const bsta_mixture<dist_>& rhs)
00112 {
00113 if (this != &rhs){
00114 for (unsigned int i=0; i<components_.size(); ++i){
00115 delete components_[i];
00116 }
00117 components_.clear();
00118 for (unsigned int i=0; i<rhs.components_.size(); ++i){
00119 components_.push_back(new component(*rhs.components_[i]));
00120 }
00121 }
00122 return *this;
00123 }
00124
00125
00126 unsigned int num_components() const { return components_.size(); }
00127
00128
00129 const dist_& distribution(unsigned int index) const
00130 { return components_[index]->distribution; }
00131
00132
00133 dist_& distribution(unsigned int index)
00134 { return components_[index]->distribution; }
00135
00136
00137 T weight(unsigned int index) const { return components_[index]->weight; }
00138
00139
00140 void set_weight(unsigned int index, const T& w)
00141 { components_[index]->weight = w; }
00142
00143
00144 bool insert(const dist_& d, const T& weight = T(0))
00145 { components_.push_back(new component(d, weight)); return true; }
00146
00147
00148 void remove_last() { delete components_.back(); components_.pop_back(); }
00149
00150
00151
00152 T prob_density(const vector_& pt) const
00153 {
00154 typedef typename vcl_vector<component*>::const_iterator comp_itr;
00155 T prob = 0;
00156 for (comp_itr i = components_.begin(); i != components_.end(); ++i)
00157 prob += (*i)->weight * (*i)->distribution.prob_density(pt);
00158 return prob;
00159 }
00160
00161
00162
00163 T probability(const vector_& min_pt, const vector_& max_pt) const
00164 {
00165 typedef typename vcl_vector<component*>::const_iterator comp_itr;
00166 T prob = 0;
00167 for (comp_itr i = components_.begin(); i != components_.end(); ++i)
00168 prob += (*i)->weight * (*i)->distribution.probability(min_pt,max_pt);
00169 return prob;
00170 }
00171
00172
00173 void normalize_weights()
00174 {
00175 typedef typename vcl_vector<component*>::iterator comp_itr;
00176 T sum = 0;
00177 for (comp_itr i = components_.begin(); i != components_.end(); ++i)
00178 sum += (*i)->weight;
00179 assert(sum > 0);
00180 for (comp_itr i = components_.begin(); i != components_.end(); ++i)
00181 (*i)->weight /= sum;
00182 }
00183
00184
00185 void sort() { vcl_sort(components_.begin(), components_.end(), sort_weight() ); }
00186
00187
00188
00189
00190
00191
00192
00193
00194 template <class comp_type_>
00195 void sort(comp_type_ comp)
00196 { vcl_sort(components_.begin(), components_.end(), sort_adaptor<comp_type_>(comp)); }
00197
00198 template <class comp_type_>
00199 void sort(comp_type_ comp, unsigned int idx)
00200 { vcl_sort(components_.begin(), components_.begin()+idx+1, sort_adaptor<comp_type_>(comp)); }
00201
00202
00203
00204
00205 vector_ sample(vnl_random& rng) const {
00206
00207 T sum = 0;
00208 for (unsigned i=0; i<num_components(); ++i)
00209 sum += components_[i].weight;
00210
00211 vcl_vector<float> ps;
00212 vcl_vector<unsigned> ids;
00213 for (unsigned i=0; i<num_components(); ++i) {
00214 float w;
00215 if (sum > 0)
00216 w = float(components_[i].weight/sum);
00217 else
00218 w = float(components_[i].weight);
00219 ps.push_back(w);
00220 ids.push_back(i);
00221 }
00222 vcl_vector<unsigned> out;
00223 bsta_sampler<unsigned>::sample(ids, ps, 1, out);
00224 assert(out.size() == 1);
00225
00226 return components_[out[0]].distribution.sample(rng);
00227
00228 }
00229 };
00230
00231 template <class dist_>
00232 inline vcl_ostream& operator<< (vcl_ostream& os,
00233 bsta_mixture<dist_> const& m)
00234 {
00235 typedef typename dist_::math_type T;
00236 unsigned n = m.num_components();
00237 for (unsigned c = 0; c<n; ++c){
00238 const dist_& mc = m.distribution(c);
00239 T weight = m.weight(c);
00240 os << "mixture_comp["<< c << "]wgt(" << weight << ")\n" << mc << '\n';
00241 }
00242 return os;
00243 }
00244
00245
00246 template <class dist>
00247 struct vpdt_is_mixture<bsta_mixture<dist> >
00248 {
00249 static const bool value = true;
00250 };
00251
00252 #endif // bsta_mixture_h_