contrib/brl/bbas/bsta/algo/bsta_beta_updater.h
Go to the documentation of this file.
00001 // This is brl/bbas/bsta/algo/bsta_beta_updater.h
00002 #ifndef bsta_beta_updater_h_
00003 #define bsta_beta_updater_h_
00004 //:
00005 // \file
00006 // \brief  Iterative updating of beta distribution
00007 // \author Gamze Tunali (gtunali@brown.edu)
00008 // \date   Nov 17, 2009
00009 //
// In this implementation $\alpha \ge 1$ and $\beta \ge 1$.
// To ensure this, the mean $\mu$ and the variance $var$ must satisfy
// $ \mu(\mu(1-\mu)/var-1)>1 $ and
// $ (1-\mu)(\mu(1-\mu)/var-1)>1 $
//
// The distance of the beta distribution is given as
// $$  -(\alpha-1)\log(x/\mu)-(\beta-1)\log((1-x)/(1-\mu)) > 3  $$
00017 //
00018 // \verbatim
00019 //  Modifications
00020 // \endverbatim
00021 
00022 #include <bsta/bsta_beta.h>
00023 #include <bsta/bsta_attributes.h>
00024 #include <vcl_algorithm.h>
00025 #include <vcl_iostream.h>
00026 
00027 //: Update the statistics given a 1D beta distribution and a learning rate
00028 // \note if rho = 1/(num observations) then this is just an online cumulative average
00029 template <class T>
00030 void bsta_update_beta(bsta_beta<T>& beta_dist, T rho, const T& sample )
00031 {
00032   // the complement of rho (i.e. rho+rho_comp=1)
00033   T rho_comp = 1 - rho;
00034   T old_mean;
00035 #if 0
00036   if (beta_dist.alpha()<1)
00037     old_mean=2*(beta_dist.alpha()-0.5);
00038   else if (beta_dist.beta()<1)
00039     old_mean=1-2*(beta_dist.beta()-0.5);
00040   else
00041 #endif // 0
00042     old_mean = beta_dist.mean();
00043 
00044   T diff = sample - old_mean;
00045   T new_var = rho_comp * beta_dist.var();
00046   new_var += (rho * rho_comp) * diff*diff;
00047 
00048   T new_mean = (old_mean) +  (rho * diff);
00049 
00050   T alpha,beta;
00051   if (!bsta_beta<T>::bsta_beta_from_moments(new_mean,new_var,alpha,beta))
00052     return;
00053   //T t = (new_mean*(1-new_mean)/new_var)-1;
00054   //T alpha=new_mean*t;
00055   //T beta=(1-new_mean)*t;
00056 
00057   if (alpha<1 && beta <1)
00058     vcl_cout<<"Mean : "<<new_mean<< "  Var: "<<new_var<<'\n';
00059   beta_dist.set_alpha_beta(alpha, beta);
00060 }
00061 
00062 template <class T>
00063 void bsta_update_beta(bsta_beta<T>& beta_dist, T rho, const T& sample , const T & min_var)
00064 {
00065   // the complement of rho (i.e. rho+rho_comp=1.0)
00066   T rho_comp = 1.0f - rho;
00067 
00068   T old_mean;
00069   if (beta_dist.alpha()<1)
00070     old_mean=2*(beta_dist.alpha()-0.5);
00071   else if (beta_dist.beta()<1)
00072     old_mean=1-2*(beta_dist.beta()-0.5);
00073   else
00074     old_mean = beta_dist.mean();
00075 
00076   T diff = sample - old_mean;
00077   T new_var = rho_comp * beta_dist.var();
00078   new_var += (rho * rho_comp) * diff*diff;
00079 
00080   new_var=vnl_math_max(new_var,min_var);
00081   T new_mean = (old_mean) +  (rho * diff);
00082 
00083   T alpha,beta;
00084   if (!bsta_beta<T>::bsta_beta_from_moments(new_mean,new_var,alpha,beta))
00085     return;
00086 
00087   beta_dist.set_alpha_beta(alpha, beta);
00088 }
00089 
//: Ordering criterion for weighted beta components
// Only the weights take part in the comparison; the distributions
// themselves are ignored.
template <class beta_>
struct bsta_beta_fitness
{
 private:
  typedef typename beta_::math_type T;
  enum { n = beta_::dimension };
 public:
  //: Return true when the first weight is strictly greater than the second
  static bool order (const beta_& /*d1*/, const T& w1,
                     const beta_& /*d2*/, const T& w2)
  {
    const bool first_is_heavier = w1 > w2;
    return first_is_heavier;
  }
};
00103 
00104 //: An updater for statistically updating beta distributions
00105 template <class beta_>
00106 class bsta_beta_updater
00107 {
00108   typedef bsta_num_obs<beta_> obs_beta_;
00109   typedef typename beta_::math_type T;
00110   typedef typename beta_::vector_type vector_;
00111  public:
00112 
00113   //: for compatibility with vpdl/vpdt
00114   typedef typename beta_::field_type field_type;
00115   typedef beta_ distribution_type;
00116 
00117 
00118   //: The main function
00119   // make the appropriate type casts and call a helper function
00120   void operator() ( obs_beta_& d, const vector_& sample ) const
00121   {
00122     d.num_observations += T(1);
00123     bsta_update_beta(d, T(1)/d.num_observations, sample);
00124   }
00125 };
00126 
00127 template <class mix_dist_>
00128 class bsta_mix_beta_updater
00129 {
00130   typedef typename mix_dist_::dist_type obs_dist_; //mixture comp type
00131   typedef typename obs_dist_::contained_type dist_; //num_obs parent
00132   typedef typename dist_::math_type T;//the field type, e.g. float
00133   typedef typename dist_::vector_type vector_;// the vector type
00134   typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00135 
00136  public:
00137   //: Constructor
00138   bsta_mix_beta_updater(const dist_& model, T thresh,  T var, unsigned int max_cmp = 5)
00139    : init_dist_(model,T(1)), max_components_(max_cmp), p_thresh_(thresh), var_(var) {}
00140 
00141   //: for compatibility with vpdl/vpdt
00142   typedef typename dist_::field_type field_type;
00143   typedef mix_dist_ distribution_type;
00144 
00145   //: The main function
00146   void operator() ( obs_mix_dist_& mix, const vector_& sample) const
00147   {
00148     mix.num_observations += T(1);
00149     this->update(mix, sample, T(1)/mix.num_observations);
00150   }
00151 
00152   void update( mix_dist_& mix, const vector_& sample, T alpha ) const;
00153 
00154  protected:
00155 
00156   //: insert a sample in the mixture
00157   void insert(mix_dist_& mixture, const vector_& sample, T init_weight) const
00158   {
00159     bool removed = mixture.num_components() >= max_components_;
00160     while (mixture.num_components() >= max_components_)
00161     {
00162       mixture.remove_last();
00163     }
00164 
00165     // if a mixture is removed renormalize the rest
00166     if (removed) {
00167       T adjust = T(0);
00168       for (unsigned int i=0; i<mixture.num_components(); ++i)
00169         adjust += mixture.weight(i);
00170       adjust = (T(1)-init_weight) / adjust;
00171       for (unsigned int i=0; i<mixture.num_components(); ++i)
00172         mixture.set_weight(i, mixture.weight(i)*adjust);
00173     }
00174 
00175     //T var = T(0.05);
00176 #if 0
00177     T t = (sample*(1-sample)/var_)-1;
00178     T alpha=sample*t;
00179     T beta=(1-sample)*t;
00180     init_dist_.set_alpha_beta(alpha,beta); ///??? this was setting mean
00181 #endif
00182     //T lower = T(0.5-vcl_sqrt(1-4*var_)/2);
00183     //T upper = T(0.5+vcl_sqrt(1-4*var_)/2);
00184 
00185     //vector_ val = sample;
00186     //if (sample < lower)
00187     //    val = lower+T(1e-6);
00188     //else if (sample > upper)
00189     //    val = upper-T(1e-6);
00190 
00191     T alpha, beta;
00192     bsta_beta<T>::bsta_beta_from_moments(sample, var_,alpha, beta);
00193     init_dist_.set_alpha_beta(alpha,beta);
00194     mixture.insert(init_dist_,init_weight);
00195   }
00196 
00197   //: A model for new beta inserted
00198   mutable obs_dist_ init_dist_;
00199   //: The maximum number of components in the mixture
00200   unsigned int max_components_;
00201   //: probability threshold
00202   T p_thresh_;
00203   T var_;
00204 };
00205 
00206 #endif