contrib/brl/bbas/bsta/bsta_beta.txx
Go to the documentation of this file.
00001 // This is brl/bbas/bsta/bsta_beta.txx
00002 #ifndef bsta_beta_txx_
00003 #define bsta_beta_txx_
00004 //:
00005 // \file
00006 
00007 #include "bsta_beta.h"
00008 #include <vnl/algo/vnl_rpoly_roots.h>
00009 #include <vnl/vnl_beta.h>
00010 #include <vnl/vnl_bignum.h>
00011 
00012 // Factorial
00013 static inline vnl_bignum factorial(int n)
00014 {
00015   if (n <= 1)
00016     return vnl_bignum(1);
00017   else
00018     return n * factorial(n-1);
00019 }
00020 
00021 
00022 //: constructs from a set of sample values
00023 template <class T>
00024 bsta_beta<T>::bsta_beta(vcl_vector<T> x)
00025 {
00026   T mean=0;
00027   T var=0;
00028 
00029   for (unsigned i=0; i<x.size(); i++) {
00030     mean+=x[i];
00031   }
00032 
00033   mean/=x.size();
00034 
00035   for (unsigned i=0; i<x.size(); i++) {
00036     T diff = x[i]-mean;
00037     var+=diff*diff;
00038   }
00039 
00040   var/=x.size();
00041 
00042   T t = (mean*(1-mean)/var)-1;
00043   alpha_=mean*t;
00044   beta_=(1-mean)*t;
00045 }
00046 
00047 template <class T>
00048 bool bsta_beta<T>::bsta_beta_from_moments(T mean, T var, T& alpha, T& beta)
00049 {
00050   if (var == 0)
00051     return false;
00052 
00053   int flag_special=0;
00054 
00055   if (mean<0.5)
00056   {
00057     if (mean*mean*mean-mean*mean+mean*var+var>0)
00058       flag_special=1;
00059   }
00060   else
00061   {
00062     if (mean*mean*mean-2*mean*mean+mean*(1+var)-2*var<0)
00063       flag_special=2;
00064   }
00065 
00066   if (flag_special==1)
00067   {
00068     alpha=T(1);
00069     vnl_vector<double> pts(4);
00070     pts[0]=1;pts[1]=3;pts[2]=4-1/var;pts[3]=2;
00071     vnl_rpoly_roots r(pts);
00072     vnl_vector<double> roots=r.realroots(0.1);
00073     bool flag=false;
00074     for (unsigned i=0;i<roots.size();i++)
00075     {
00076       if (roots[i]>0)
00077       {
00078         flag=true;
00079         beta=T(roots[i]);
00080       }
00081     }
00082     return flag;
00083   }
00084   else if (flag_special==2)
00085   {
00086     beta=T(1);
00087     vnl_vector<double> pts(4);
00088     pts[0]=1;pts[1]=3;pts[2]=4-1/var;pts[3]=2;
00089     vnl_rpoly_roots r(pts);
00090     vnl_vector<double> roots=r.realroots(0.1);
00091 
00092     bool flag=false;
00093     for (unsigned i=0;i<roots.size();i++)
00094     {
00095       if (roots[i]>0)
00096       {
00097         flag=true;
00098         alpha=T(roots[i]);
00099       }
00100     }
00101     return flag;
00102   }
00103   else
00104   {
00105     T t = mean*(1-mean)/var-1;
00106     alpha=mean*t;
00107     beta=(1-mean)*t;
00108   }
00109 #if 0 // commented out ...
00110   T det=vcl_sqrt(1-12*var);
00111   if (mean<=(1-det)/2)
00112   {
00113     alpha=T(1);
00114     vnl_vector<double> pts(4);
00115     pts[0]=1;pts[1]=3;pts[2]=4-1/var;pts[3]=2;
00116     vnl_rpoly_roots r(pts);
00117     vnl_vector<double> roots=r.realroots(0.1);
00118 
00119     bool flag=false;
00120     for (unsigned i=0;i<roots.size();i++)
00121     {
00122       if (roots[i]>0)
00123       {
00124         flag=true;
00125         beta=T(roots[i]);
00126       }
00127     }
00128     return flag;
00129   }
00130   else if (mean>=(1+det)/2)
00131   {
00132     beta=T(1);
00133     vnl_vector<double> pts(4);
00134     pts[0]=1;pts[1]=3;pts[2]=4-1/var;pts[3]=2;
00135     vnl_rpoly_roots r(pts);
00136     vnl_vector<double> roots=r.realroots(0.1);
00137 
00138     bool flag=false;
00139     for (unsigned i=0;i<roots.size();i++)
00140     {
00141       if (roots[i]>0)
00142       {
00143         flag=true;
00144         alpha=T(roots[i]);
00145       }
00146     }
00147     return flag;
00148   }
00149   else
00150   {
00151     T t = mean*(1-mean)/var-1;
00152     alpha=mean*t;
00153     beta=(1-mean)*t;
00154   }
00155 #endif // 0
00156   return true;
00157 }
00158 
00159 //:
00160 // pre: x should be in [0,1]
00161 // Otherwise, zero is returned.
00162 template <class T>
00163 T bsta_beta<T>::prob_density(T x) const
00164 {
00165   if (x==0.0)
00166     x+=T(1e-10);
00167   else if (x==1.0)
00168     x-=T(1e-10);
00169   if (x<T(0)||x>T(1))
00170     return 0;
00171   else
00172   {
00173     double a = vnl_log_beta(alpha_,beta_);
00174     double b = (alpha_-1)*vcl_log(x);
00175     double c = (beta_-1)*vcl_log(1-x);
00176 
00177     if (b+c-a<-60)
00178       return T(0);
00179     else if (b+c-a>60)
00180       return T(100);
00181     else
00182       return (T)vcl_exp(b+c-a);
00183   }
00184 }
00185 
00186 template <class T>
00187 T bsta_beta<T>::distance(T x) const
00188 {
00189   T mean =alpha_/(alpha_+beta_);
00190   if (x==0 && alpha_==1)
00191     return (T)((beta_-1)*vcl_log((1-x)/(1-mean)));
00192   else if (x==1 && beta_==1)
00193     return (T)((alpha_-1)*vcl_log(x/mean));
00194   else
00195     return (T)((alpha_-1)*vcl_log(x/mean)+(beta_-1)*vcl_log((1-x)/(1-mean)));
00196 }
00197 
00198 // cumulative distribution function
00199 template <class T>
00200 T bsta_beta<T>::cum_dist_funct(T x) const
00201 {
00202   unsigned a = static_cast<unsigned>(alpha_);
00203   unsigned b = static_cast<unsigned>(beta_);
00204   T Ix=T(0);
00205   T val;
00206   for (unsigned j=a; j<=a+b-1; j++) {
00207     val = factorial(a+b-1)/(factorial(j)*factorial(a+b-1-j));
00208     val *= vcl_pow(x,T(j))*vcl_pow(1-x, T(a+b-1-j));
00209     Ix+=val;
00210   }
00211   return Ix;
00212 }
00213 
//: Explicit-instantiation macro: BSTA_BETA_INSTANTIATE(T) expands to
// "template class bsta_beta<T >". Any earlier definition is removed first.
#undef BSTA_BETA_INSTANTIATE
#define BSTA_BETA_INSTANTIATE(T) \
template class bsta_beta<T >
00217 
00218 #endif