00001
00002 #ifndef bsta_adaptive_updater_h_
00003 #define bsta_adaptive_updater_h_
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019 #include <bsta/bsta_distribution.h>
00020 #include <bsta/bsta_mixture.h>
00021 #include <bsta/bsta_mixture_fixed.h>
00022 #include <bsta/bsta_attributes.h>
00023 #include <bsta/bsta_gauss_ff3.h>
00024 #include "bsta_gaussian_updater.h"
00025 #include "bsta_gaussian_stats.h"
00026
00027
00028
00029
//: Base class for adaptive mixture-of-Gaussians updaters.
//  Holds the prototype component used when a new mode is inserted and the
//  cap on the number of mixture components.  Construction and insertion are
//  protected, so only derived updaters can use them.
template <class mix_dist_>
class bsta_mg_adaptive_updater
{
 private:
  typedef typename mix_dist_::dist_type obs_gaussian_;
  typedef typename obs_gaussian_::contained_type gaussian_;
  typedef typename gaussian_::math_type T;
  typedef typename gaussian_::vector_type vector_;

 public:
  //: The field type of the component Gaussians
  typedef typename gaussian_::field_type field_type;
  //: The distribution type this updater operates on
  typedef mix_dist_ distribution_type;

 protected:
  //: Constructor
  //  \param model   prototype Gaussian copied into every new component
  //                 (only its mean is overwritten on insertion)
  //  \param max_cmp maximum number of components kept in the mixture
  bsta_mg_adaptive_updater(const gaussian_& model,
                           unsigned int max_cmp = 5)
    : init_gaussian_(model,T(1)),
      max_components_(max_cmp) {}

  //: Insert a new component centred on \p sample with weight \p init_weight.
  //  If the mixture already holds max_components_ or more components, the
  //  last components are removed until there is room. The surviving weights
  //  are then rescaled to sum to (1 - init_weight).
  //  If the mixture is empty at insertion time, the new component gets
  //  weight 1 regardless of \p init_weight.  A cap of 0 therefore results in
  //  a single component of weight 1.
  //  NOTE(review): when nothing is removed, the existing weights are not
  //  rescaled here; presumably the caller or the mixture renormalizes - confirm.
  void insert(mix_dist_& mixture, const vector_& sample, T init_weight) const
  {
    bool removed = false;
    // The "> 0" guard stops the loop on an empty mixture. Without it a cap
    // of 0 would loop forever (unsigned >= 0 is always true) and pop from
    // an empty mixture.
    while (mixture.num_components() > 0 &&
           mixture.num_components() >= max_components_) {
      mixture.remove_last();
      removed = true;
    }

    // If components were evicted, rescale the survivors to leave room for
    // the new component's weight.
    if (removed) {
      T total = T(0);
      for (unsigned int i=0; i<mixture.num_components(); ++i)
        total += mixture.weight(i);
      // Nothing (or only zero weight) survived: rescaling would divide by 0.
      if (total > T(0)) {
        const T adjust = (T(1)-init_weight) / total;
        for (unsigned int i=0; i<mixture.num_components(); ++i)
          mixture.set_weight(i, mixture.weight(i)*adjust);
      }
    }

    init_gaussian_.set_mean(sample);

    if (mixture.num_components()>0)
      mixture.insert(init_gaussian_,init_weight);
    else
      mixture.insert(init_gaussian_,T(1));
  }

  //: Prototype for new components (mutable because its mean is reset by insert())
  mutable obs_gaussian_ init_gaussian_;

  //: Maximum number of components in the mixture
  unsigned int max_components_;
};
00085
00086
00087
00088
00089 template <class mix_dist_>
00090 class bsta_mg_statistical_updater : public bsta_mg_adaptive_updater<mix_dist_>
00091 {
00092 public:
00093 typedef typename mix_dist_::dist_type obs_gaussian_;
00094 typedef typename obs_gaussian_::contained_type gaussian_;
00095 typedef typename gaussian_::math_type T;
00096 typedef typename gaussian_::vector_type vector_;
00097 typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00098
00099
00100 typedef obs_mix_dist_ distribution_type;
00101
00102 enum { data_dimension = gaussian_::dimension };
00103
00104
00105 bsta_mg_statistical_updater(const gaussian_& model,
00106 unsigned int max_cmp = 5,
00107 T g_thresh = T(3),
00108 T min_stdev = T(0))
00109 : bsta_mg_adaptive_updater<mix_dist_>(model, max_cmp),
00110 gt2_(g_thresh*g_thresh), min_var_(min_stdev*min_stdev) {}
00111
00112
00113 void operator() ( obs_mix_dist_& mix, const vector_& sample ) const
00114 {
00115 mix.num_observations += T(1);
00116 this->update(mix, sample, T(1)/mix.num_observations);
00117 }
00118
00119 void update( mix_dist_& mix, const vector_& sample, T alpha ) const;
00120 #if 0
00121 void update( mix_dist_& mix, const T & sample, T alpha ) const;
00122 #endif
00123
00124 T gt2_;
00125
00126 T min_var_;
00127 };
00128
00129
00130
00131
00132 template <class mix_dist_>
00133 class bsta_mg_window_updater : public bsta_mg_statistical_updater<mix_dist_>
00134 {
00135 public:
00136 typedef typename mix_dist_::dist_type obs_gaussian_;
00137 typedef typename obs_gaussian_::contained_type gaussian_;
00138 typedef typename gaussian_::math_type T;
00139 typedef typename gaussian_::vector_type vector_;
00140 typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00141
00142
00143 typedef obs_mix_dist_ distribution_type;
00144
00145 enum { data_dimension = gaussian_::dimension };
00146
00147
00148 bsta_mg_window_updater(const gaussian_& model,
00149 unsigned int max_cmp = 5,
00150 T g_thresh = T(3),
00151 T min_stdev = T(0),
00152 unsigned int window_size = 40)
00153 : bsta_mg_statistical_updater<mix_dist_>(model, max_cmp, g_thresh, min_stdev),
00154 window_size_(window_size) {}
00155
00156
00157 void operator() ( obs_mix_dist_& mix, const vector_& sample ) const
00158 {
00159 if (mix.num_observations < window_size_)
00160 mix.num_observations += T(1);
00161 this->update(mix, sample, T(1)/mix.num_observations);
00162 }
00163
00164 protected:
00165 unsigned int window_size_;
00166 };
00167
00168
00169
00170 template <class mix_dist_>
00171 class bsta_mg_weighted_updater : bsta_mg_statistical_updater<mix_dist_>
00172 {
00173 public:
00174 typedef typename mix_dist_::dist_type obs_gaussian_;
00175 typedef typename obs_gaussian_::contained_type gaussian_;
00176 typedef typename gaussian_::math_type T;
00177 typedef typename gaussian_::vector_type vector_;
00178 typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00179
00180
00181 typedef obs_mix_dist_ distribution_type;
00182
00183 enum { data_dimension = gaussian_::dimension };
00184
00185
00186 bsta_mg_weighted_updater(const gaussian_& model,
00187 unsigned int max_cmp = 5,
00188 T g_thresh = T(3),
00189 T min_stdev = T(0))
00190 : bsta_mg_statistical_updater<mix_dist_>(model, max_cmp, g_thresh, min_stdev){}
00191
00192
00193 void operator() ( obs_mix_dist_& mix, const vector_& sample, const T weight ) const
00194 {
00195 mix.num_observations += weight;
00196 this->update(mix, sample, weight/mix.num_observations);
00197 }
00198 };
00199
00200
00201
00202
00203 template <class mix_dist_>
00204 class bsta_mg_grimson_statistical_updater : public bsta_mg_adaptive_updater<mix_dist_>
00205 {
00206 public:
00207 typedef typename mix_dist_::dist_type obs_gaussian_;
00208 typedef typename obs_gaussian_::contained_type gaussian_;
00209 typedef typename gaussian_::math_type T;
00210 typedef typename gaussian_::vector_type vector_;
00211 typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00212
00213
00214 typedef obs_mix_dist_ distribution_type;
00215
00216 enum { data_dimension = gaussian_::dimension };
00217
00218
00219 bsta_mg_grimson_statistical_updater(const gaussian_& model,
00220 unsigned int max_cmp = 5,
00221 T g_thresh = T(3),
00222 T min_stdev = T(0) )
00223 : bsta_mg_adaptive_updater<mix_dist_>(model, max_cmp),
00224 gt2_(g_thresh*g_thresh), min_var_(min_stdev*min_stdev) {}
00225
00226
00227 void operator() ( obs_mix_dist_& mix, const vector_& sample ) const
00228 {
00229 mix.num_observations += T(1);
00230 this->update(mix, sample, T(1)/mix.num_observations);
00231 }
00232
00233 void update( mix_dist_& mix, const vector_& sample, T alpha ) const;
00234
00235
00236 T gt2_;
00237
00238 T min_var_;
00239 };
00240
00241
00242
00243 template <class mix_dist_>
00244 class bsta_mg_grimson_window_updater : public bsta_mg_grimson_statistical_updater<mix_dist_>
00245 {
00246 public:
00247 typedef typename mix_dist_::dist_type obs_gaussian_;
00248 typedef typename obs_gaussian_::contained_type gaussian_;
00249 typedef typename gaussian_::math_type T;
00250 typedef typename gaussian_::vector_type vector_;
00251 typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00252
00253
00254 typedef obs_mix_dist_ distribution_type;
00255
00256 enum { data_dimension = gaussian_::dimension };
00257
00258
00259 bsta_mg_grimson_window_updater(const gaussian_& model,
00260 unsigned int max_cmp = 5,
00261 T g_thresh = T(3),
00262 T min_stdev = T(0),
00263 unsigned int window_size = 40)
00264 : bsta_mg_grimson_statistical_updater<mix_dist_>(model, max_cmp, g_thresh, min_stdev),
00265 window_size_(window_size) {}
00266
00267
00268 void operator() ( obs_mix_dist_& mix, const vector_& sample ) const
00269 {
00270 if (mix.num_observations < window_size_)
00271 mix.num_observations += T(1);
00272 this->update(mix, sample, T(1)/mix.num_observations);
00273 }
00274
00275 protected:
00276 unsigned int window_size_;
00277 };
00278
00279
00280
00281 template <class mix_dist_>
00282 class bsta_mg_grimson_weighted_updater : bsta_mg_grimson_statistical_updater<mix_dist_>
00283 {
00284 public:
00285 typedef typename mix_dist_::dist_type obs_gaussian_;
00286 typedef typename obs_gaussian_::contained_type gaussian_;
00287 typedef typename gaussian_::math_type T;
00288 typedef typename gaussian_::vector_type vector_;
00289 typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00290
00291
00292 typedef obs_mix_dist_ distribution_type;
00293
00294 enum { data_dimension = gaussian_::dimension };
00295
00296
00297 bsta_mg_grimson_weighted_updater(const gaussian_& model,
00298 unsigned int max_cmp = 5,
00299 T g_thresh = T(3),
00300 T min_stdev = T(0) )
00301 : bsta_mg_grimson_statistical_updater<mix_dist_>(model, max_cmp, g_thresh, min_stdev){}
00302
00303
00304 void operator() ( obs_mix_dist_& mix, const vector_& sample, const T weight ) const
00305 {
00306 mix.num_observations += weight;
00307 this->update(mix, sample, weight/mix.num_observations);
00308 }
00309 };
00310
00311
00312 #endif // bsta_adaptive_updater_h_