Go to the documentation of this file.00001
00002 #ifndef bsta_mixture_fixed_h_
00003 #define bsta_mixture_fixed_h_
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
#include "bsta_distribution.h"
#include "bsta_sampler.h"
#include <vcl_cassert.h>
#include <vcl_algorithm.h>
#include <vcl_iostream.h>
#include <vcl_vector.h>
#include <vpdl/vpdt/vpdt_dist_traits.h>
#include <vnl/vnl_random.h>
00023
00024
00025 template <class dist_, unsigned s>
00026 class bsta_mixture_fixed : public bsta_distribution<typename dist_::math_type,
00027 dist_::dimension>
00028 {
00029 public:
00030 typedef dist_ dist_type;
00031 typedef dist_type component_type;
00032 enum { max_components = s };
00033
00034 private:
00035 typedef typename dist_::math_type T;
00036 typedef typename dist_::vector_type vector_;
00037
00038
00039
00040 struct component
00041 {
00042
00043 component(): distribution(), weight(T(0)) {}
00044
00045 component(const dist_& d, const T& w = T(0) )
00046 : distribution(d), weight(w) {}
00047
00048
00049 bool operator< (const component& rhs) const
00050 { return this->weight > rhs.weight; }
00051
00052
00053
00054
00055 dist_ distribution;
00056
00057 T weight;
00058 };
00059
00060
00061
00062 class sort_weight
00063 {
00064 public:
00065 bool operator() (const component c1, const component c2) const
00066 { return c1.weight > c2.weight; }
00067 };
00068
00069
00070 template <class comp_type_>
00071 class sort_adaptor
00072 {
00073 public:
00074 sort_adaptor(comp_type_ c) : comp(c) {}
00075 bool operator() (const component& c1, const component& c2) const
00076 { return comp(c1.distribution, c1.weight, c2.distribution, c2.weight); }
00077 comp_type_ comp;
00078 };
00079
00080
00081 component components_[s];
00082
00083 unsigned num_components_;
00084
00085 public:
00086
00087 bsta_mixture_fixed<dist_,s>() : num_components_(0) {}
00088
00089
00090 bsta_mixture_fixed<dist_,s>(const bsta_mixture_fixed<dist_,s>& other)
00091 : num_components_(other.num_components_)
00092 {
00093
00094 for (unsigned int i=0; i<s; ++i){
00095 components_[i] = other.components_[i];
00096 }
00097 }
00098
00099
00100 ~bsta_mixture_fixed<dist_,s>()
00101 {
00102 }
00103
00104
00105 bsta_mixture_fixed<dist_,s>& operator= (const bsta_mixture_fixed<dist_,s>& rhs)
00106 {
00107 if (this != &rhs) {
00108
00109 for (unsigned int i=0; i<s; ++i) {
00110 components_[i] = rhs.components_[i];
00111 }
00112 num_components_ = rhs.num_components_;
00113 }
00114 return *this;
00115 }
00116
00117
00118 unsigned int num_components() const { return num_components_; }
00119
00120
00121 const dist_& distribution(unsigned int index) const
00122 { return components_[index].distribution; }
00123
00124
00125 dist_& distribution(unsigned int index)
00126 { return components_[index].distribution; }
00127
00128
00129 T weight(unsigned int index) const { return components_[index].weight; }
00130
00131
00132 void set_weight(unsigned int index, const T& w) { components_[index].weight = w; }
00133
00134
00135 bool insert(const dist_& d, const T& weight = T(0))
00136 {
00137 if (num_components_ >= s)
00138 return false;
00139
00140 components_[num_components_++] = component(d, weight);
00141 return true;
00142 }
00143
00144
00145 void remove_last() { components_[--num_components_].weight = T(0); }
00146
00147
00148
00149 T prob_density(const vector_& pt) const
00150 {
00151 T prob = 0;
00152
00153 for (unsigned i=0; i<num_components_; ++i)
00154 prob += components_[i].weight
00155 * components_[i].distribution.prob_density(pt);
00156 return prob;
00157 }
00158
00159
00160
00161 T probability(const vector_& min_pt, const vector_& max_pt) const
00162 {
00163 T prob = 0;
00164
00165 for (unsigned i=0; i<num_components_; ++i)
00166 prob += components_[i].weight
00167 * components_[i].distribution.probability(min_pt,max_pt);
00168 return prob;
00169 }
00170
00171
00172
00173 vector_ expected_value()
00174 {
00175 vector_ expected_value(T(0));
00176 for (unsigned i=0; i<num_components_; ++i)
00177 expected_value += components_[i].weight
00178 * components_[i].distribution.mean();
00179 return expected_value;
00180 }
00181
00182
00183 void normalize_weights()
00184 {
00185 T sum = 0;
00186 for (unsigned i=0; i<num_components_; ++i)
00187 sum += components_[i].weight;
00188 assert(sum > 0);
00189 for (unsigned i=0; i<num_components_; ++i)
00190 components_[i].weight /= sum;
00191 }
00192
00193
00194 void sort() { vcl_sort(components_, components_+num_components_, sort_weight() ); }
00195
00196
00197
00198
00199
00200
00201
00202
00203 template <class comp_type_>
00204 void sort(comp_type_ comp)
00205 { vcl_sort(components_, components_+num_components_, sort_adaptor<comp_type_>(comp)); }
00206
00207
00208 template <class comp_type_>
00209 void sort(comp_type_ comp, unsigned int idx)
00210 { assert(idx < s);
00211 vcl_sort(components_, components_+idx+1, sort_adaptor<comp_type_>(comp)); }
00212
00213
00214
00215
00216 vector_ sample(vnl_random& rng) const
00217 {
00218
00219 T sum = 0;
00220 for (unsigned i=0; i<num_components_; ++i)
00221 sum += components_[i].weight;
00222
00223 vcl_vector<float> ps;
00224 vcl_vector<unsigned> ids;
00225 for (unsigned i=0; i<num_components_; ++i) {
00226 float w;
00227 if (sum > 0)
00228 w = float(components_[i].weight/sum);
00229 else
00230 w = float(components_[i].weight);
00231 ps.push_back(w);
00232 ids.push_back(i);
00233 }
00234 vcl_vector<unsigned> out;
00235 bsta_sampler<unsigned>::sample(ids, ps, 1, out, rng);
00236 assert(out.size() == 1);
00237
00238 return components_[out[0]].distribution.sample(rng);
00239
00240 }
00241 };
00242
00243 template <class dist_, unsigned s>
00244 inline vcl_ostream& operator<< (vcl_ostream& os,
00245 bsta_mixture_fixed<dist_,s> const& no)
00246 {
00247 for (unsigned i=0; i<no.num_components(); ++i)
00248 os<<"Component #"<<i<<" weight=: "<<no.weight(i)<<"distribution: "<<no.distribution(i)<<vcl_endl;
00249 return os;
00250 }
00251
00252
00253 template <class dist, unsigned s>
00254 struct vpdt_is_mixture<bsta_mixture_fixed<dist,s> >
00255 {
00256 static const bool value = true;
00257 };
00258
00259 #endif // bsta_mixture_fixed_h_