Go to the documentation of this file.00001
00002 #ifndef bsta_beta_updater_h_
00003 #define bsta_beta_updater_h_
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include <bsta/bsta_beta.h>
00023 #include <bsta/bsta_attributes.h>
00024 #include <vcl_algorithm.h>
00025 #include <vcl_iostream.h>
00026
00027
00028
00029 template <class T>
00030 void bsta_update_beta(bsta_beta<T>& beta_dist, T rho, const T& sample )
00031 {
00032
00033 T rho_comp = 1 - rho;
00034 T old_mean;
00035 #if 0
00036 if (beta_dist.alpha()<1)
00037 old_mean=2*(beta_dist.alpha()-0.5);
00038 else if (beta_dist.beta()<1)
00039 old_mean=1-2*(beta_dist.beta()-0.5);
00040 else
00041 #endif // 0
00042 old_mean = beta_dist.mean();
00043
00044 T diff = sample - old_mean;
00045 T new_var = rho_comp * beta_dist.var();
00046 new_var += (rho * rho_comp) * diff*diff;
00047
00048 T new_mean = (old_mean) + (rho * diff);
00049
00050 T alpha,beta;
00051 if (!bsta_beta<T>::bsta_beta_from_moments(new_mean,new_var,alpha,beta))
00052 return;
00053
00054
00055
00056
00057 if (alpha<1 && beta <1)
00058 vcl_cout<<"Mean : "<<new_mean<< " Var: "<<new_var<<'\n';
00059 beta_dist.set_alpha_beta(alpha, beta);
00060 }
00061
00062 template <class T>
00063 void bsta_update_beta(bsta_beta<T>& beta_dist, T rho, const T& sample , const T & min_var)
00064 {
00065
00066 T rho_comp = 1.0f - rho;
00067
00068 T old_mean;
00069 if (beta_dist.alpha()<1)
00070 old_mean=2*(beta_dist.alpha()-0.5);
00071 else if (beta_dist.beta()<1)
00072 old_mean=1-2*(beta_dist.beta()-0.5);
00073 else
00074 old_mean = beta_dist.mean();
00075
00076 T diff = sample - old_mean;
00077 T new_var = rho_comp * beta_dist.var();
00078 new_var += (rho * rho_comp) * diff*diff;
00079
00080 new_var=vnl_math_max(new_var,min_var);
00081 T new_mean = (old_mean) + (rho * diff);
00082
00083 T alpha,beta;
00084 if (!bsta_beta<T>::bsta_beta_from_moments(new_mean,new_var,alpha,beta))
00085 return;
00086
00087 beta_dist.set_alpha_beta(alpha, beta);
00088 }
00089
//: Weight-only ordering predicate for beta components.
//  order() ignores both distributions and reports whether the first weight
//  is strictly greater than the second.
template <class beta_>
struct bsta_beta_fitness
{
 private:
  typedef typename beta_::math_type T;
  // Not referenced by order(); instantiating this struct still requires beta_::dimension.
  enum { n = beta_::dimension };

 public:
  //: True when \p w1 is strictly larger than \p w2 (distribution arguments unused).
  static bool order (const beta_& /*d1*/, const T& w1,
                     const beta_& /*d2*/, const T& w2)
  {
    const bool first_heavier = (w2 < w1);
    return first_heavier;
  }
};
00103
00104
00105 template <class beta_>
00106 class bsta_beta_updater
00107 {
00108 typedef bsta_num_obs<beta_> obs_beta_;
00109 typedef typename beta_::math_type T;
00110 typedef typename beta_::vector_type vector_;
00111 public:
00112
00113
00114 typedef typename beta_::field_type field_type;
00115 typedef beta_ distribution_type;
00116
00117
00118
00119
00120 void operator() ( obs_beta_& d, const vector_& sample ) const
00121 {
00122 d.num_observations += T(1);
00123 bsta_update_beta(d, T(1)/d.num_observations, sample);
00124 }
00125 };
00126
00127 template <class mix_dist_>
00128 class bsta_mix_beta_updater
00129 {
00130 typedef typename mix_dist_::dist_type obs_dist_;
00131 typedef typename obs_dist_::contained_type dist_;
00132 typedef typename dist_::math_type T;
00133 typedef typename dist_::vector_type vector_;
00134 typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00135
00136 public:
00137
00138 bsta_mix_beta_updater(const dist_& model, T thresh, T var, unsigned int max_cmp = 5)
00139 : init_dist_(model,T(1)), max_components_(max_cmp), p_thresh_(thresh), var_(var) {}
00140
00141
00142 typedef typename dist_::field_type field_type;
00143 typedef mix_dist_ distribution_type;
00144
00145
00146 void operator() ( obs_mix_dist_& mix, const vector_& sample) const
00147 {
00148 mix.num_observations += T(1);
00149 this->update(mix, sample, T(1)/mix.num_observations);
00150 }
00151
00152 void update( mix_dist_& mix, const vector_& sample, T alpha ) const;
00153
00154 protected:
00155
00156
00157 void insert(mix_dist_& mixture, const vector_& sample, T init_weight) const
00158 {
00159 bool removed = mixture.num_components() >= max_components_;
00160 while (mixture.num_components() >= max_components_)
00161 {
00162 mixture.remove_last();
00163 }
00164
00165
00166 if (removed) {
00167 T adjust = T(0);
00168 for (unsigned int i=0; i<mixture.num_components(); ++i)
00169 adjust += mixture.weight(i);
00170 adjust = (T(1)-init_weight) / adjust;
00171 for (unsigned int i=0; i<mixture.num_components(); ++i)
00172 mixture.set_weight(i, mixture.weight(i)*adjust);
00173 }
00174
00175
00176 #if 0
00177 T t = (sample*(1-sample)/var_)-1;
00178 T alpha=sample*t;
00179 T beta=(1-sample)*t;
00180 init_dist_.set_alpha_beta(alpha,beta);
00181 #endif
00182
00183
00184
00185
00186
00187
00188
00189
00190
00191 T alpha, beta;
00192 bsta_beta<T>::bsta_beta_from_moments(sample, var_,alpha, beta);
00193 init_dist_.set_alpha_beta(alpha,beta);
00194 mixture.insert(init_dist_,init_weight);
00195 }
00196
00197
00198 mutable obs_dist_ init_dist_;
00199
00200 unsigned int max_components_;
00201
00202 T p_thresh_;
00203 T var_;
00204 };
00205
00206 #endif