contrib/brl/bbas/bsta/bsta_mixture.h
Go to the documentation of this file.
00001 // This is brl/bbas/bsta/bsta_mixture.h
00002 #ifndef bsta_mixture_h_
00003 #define bsta_mixture_h_
00004 //:
00005 // \file
00006 // \brief A mixture of distributions
00007 // \author Matt Leotta (mleotta@lems.brown.edu)
00008 // \date January 26, 2006
00009 //
00010 // \verbatim
00011 //  Modifications
00012 //   Jan 21 2008  -  Matt Leotta  -  Rename probability to prob_density and
00013 //                                   add probability integration over a box
00014 // \endverbatim
00015 
00016 #include "bsta_distribution.h"
00017 #include <vcl_cassert.h>
00018 #include <vcl_vector.h>
00019 #include <vcl_algorithm.h>
00020 #include <vcl_iostream.h>
00021 #include "bsta_sampler.h"
00022 #include <vpdl/vpdt/vpdt_dist_traits.h>
00023 #include <vnl/vnl_random.h>
00024 
00025 //: A mixture of distributions
00026 template <class dist_>
00027 class bsta_mixture : public bsta_distribution<typename dist_::math_type,
00028                                                        dist_::dimension>
00029 {
00030  public:
00031   typedef dist_ dist_type;
00032   typedef dist_type component_type; // for compatibility with vpdl/vpdt
00033   // unlimited number of component is indicated by 0
00034   enum { max_components = 0 };
00035 
00036  private:
00037   typedef typename dist_::math_type T;
00038   typedef typename dist_::vector_type vector_;
00039 
00040   //: A struct to hold the component distributions and weights
00041   // This class is private and should not be used outside of the mixture.
00042   // Dynamic memory is used to allow for polymorphic distributions.
00043   // However, this use of memory is self-contained and private so the user
00044   // should not be able to introduce a memory leak
00045   struct component
00046   {
00047     //: Constructor
00048     component(): distribution(), weight(T(0)) {}
00049     //: Constructor
00050     component(const dist_& d, const T& w = T(0) )
00051       : distribution(d), weight(w) {}
00052 
00053     //: Used to sort by decreasing weight
00054     bool operator< (const component& rhs) const
00055     { return this->weight > rhs.weight; }
00056 
00057     // ============ Data =============
00058 
00059     //: The distribution
00060     dist_ distribution;
00061     //: The weight
00062     T weight;
00063   };
00064 
00065   //: This functor is used by default for sorting with STL
00066   // The default sorting is decreasing by weight
00067   class sort_weight
00068   {
00069    public:
00070     bool operator() (const component* c1, const component* c2) const
00071       { return c1->weight > c2->weight; }
00072   };
00073 
00074   //: This adaptor allows users to define ordering functors on the components without accessing the components directly
00075   template <class comp_type_>
00076   class sort_adaptor
00077   {
00078    public:
00079     sort_adaptor(comp_type_ c) : comp(c) {}
00080     bool operator() (const component* const c1, const component* const c2) const
00081     { return comp(c1->distribution, c1->weight, c2->distribution, c2->weight); }
00082     comp_type_ comp;
00083   };
00084 
00085   //: The vector of components
00086   vcl_vector<component*> components_;
00087 
00088  public:
00089   // Default Constructor
00090   bsta_mixture<dist_>() {}
00091 
00092   // Copy Constructor
00093   bsta_mixture<dist_>(const bsta_mixture<dist_>& other)
00094     : components_(other.components_.size(),NULL)
00095   {
00096     // deep copy of the data
00097     for (unsigned int i=0; i<components_.size(); ++i){
00098       components_[i] = new component(*other.components_[i]);
00099     }
00100   }
00101 
00102   // Destructor
00103   ~bsta_mixture<dist_>()
00104   {
00105     for (unsigned int i=0; i<components_.size(); ++i){
00106       delete components_[i];
00107     }
00108   }
00109 
00110   //: Assignment operator
00111   bsta_mixture<dist_>& operator= (const bsta_mixture<dist_>& rhs)
00112   {
00113     if (this != &rhs){
00114       for (unsigned int i=0; i<components_.size(); ++i){
00115         delete components_[i];
00116       }
00117       components_.clear();
00118       for (unsigned int i=0; i<rhs.components_.size(); ++i){
00119         components_.push_back(new component(*rhs.components_[i]));
00120       }
00121     }
00122     return *this;
00123   }
00124 
00125   //: Return the number of components in the mixture
00126   unsigned int num_components() const { return components_.size(); }
00127 
00128   //: Access (const) a component distribution of the mixture
00129   const dist_& distribution(unsigned int index) const
00130   { return components_[index]->distribution; }
00131 
00132   //: Access a component distribution of the mixture
00133   dist_& distribution(unsigned int index)
00134   { return components_[index]->distribution; }
00135 
00136   //: Return the weight of a component in the mixture
00137   T weight(unsigned int index) const { return components_[index]->weight; }
00138 
00139   //: Set the weight of a component in the mixture
00140   void set_weight(unsigned int index, const T& w)
00141   { components_[index]->weight = w; }
00142 
00143   //: Insert a new component at the end of the vector
00144   bool insert(const dist_& d, const T& weight = T(0))
00145   { components_.push_back(new component(d, weight)); return true; }
00146 
00147   //: Remove the last component in the vector
00148   void remove_last() { delete components_.back(); components_.pop_back(); }
00149 
00150   //: Compute the probability density at this point
00151   // \note assumes weights have been normalized
00152   T prob_density(const vector_& pt) const
00153   {
00154     typedef typename vcl_vector<component*>::const_iterator comp_itr;
00155     T prob = 0;
00156     for (comp_itr i = components_.begin(); i != components_.end(); ++i)
00157       prob += (*i)->weight * (*i)->distribution.prob_density(pt);
00158     return prob;
00159   }
00160 
00161   //: The probability integrated over a box
00162   // \note assumes weights have been normalized
00163   T probability(const vector_& min_pt, const vector_& max_pt) const
00164   {
00165     typedef typename vcl_vector<component*>::const_iterator comp_itr;
00166     T prob = 0;
00167     for (comp_itr i = components_.begin(); i != components_.end(); ++i)
00168       prob += (*i)->weight * (*i)->distribution.probability(min_pt,max_pt);
00169     return prob;
00170   }
00171 
00172   //: Normalize the weights of the components to add to 1.
00173   void normalize_weights()
00174   {
00175     typedef typename vcl_vector<component*>::iterator comp_itr;
00176     T sum = 0;
00177     for (comp_itr i = components_.begin(); i != components_.end(); ++i)
00178       sum += (*i)->weight;
00179     assert(sum > 0);
00180     for (comp_itr i = components_.begin(); i != components_.end(); ++i)
00181       (*i)->weight /= sum;
00182   }
00183 
00184   //: Sort the components in order of decreasing weight
00185   void sort() { vcl_sort(components_.begin(), components_.end(), sort_weight() ); }
00186 
00187   //: Sort the components using any StrictWeakOrdering function
00188   // The prototype should be
00189   // \code
00190   // template <class T>
00191   // bool functor(const bsta_distribution<T>& d1, const T& w1,
00192   //              const bsta_distribution<T>& d2, const T& w2);
00193   // \endcode
00194   template <class comp_type_>
00195   void sort(comp_type_ comp)
00196   { vcl_sort(components_.begin(), components_.end(), sort_adaptor<comp_type_>(comp)); }
00197 
00198   template <class comp_type_>
00199   void sort(comp_type_ comp, unsigned int idx)
00200   { vcl_sort(components_.begin(), components_.begin()+idx+1, sort_adaptor<comp_type_>(comp)); }
00201 
00202   //: sample from the mixture
00203   //  randomly selects a component wrt normalized component weights, then for now returns the mean of the selected component
00204   //  \todo write a method to sample from the distribution and use it instead of the mean
00205   vector_ sample(vnl_random& rng) const {
00206     //: first normalize the weights (this is const methods so we cannot call the class-method normalize_weights()
00207     T sum = 0;
00208     for (unsigned i=0; i<num_components(); ++i)
00209       sum += components_[i].weight;
00210 
00211     vcl_vector<float> ps;
00212     vcl_vector<unsigned> ids;
00213     for (unsigned i=0; i<num_components(); ++i) {
00214       float w;
00215       if (sum > 0)
00216         w = float(components_[i].weight/sum);
00217       else
00218         w = float(components_[i].weight);
00219       ps.push_back(w);
00220       ids.push_back(i);
00221     }
00222     vcl_vector<unsigned> out;
00223     bsta_sampler<unsigned>::sample(ids, ps, 1, out);
00224     assert(out.size() == 1);
00225 
00226     return components_[out[0]].distribution.sample(rng);
00227     //return components_[out[0]].distribution.mean();
00228   }
00229 };
00230 
00231 template <class dist_>
00232 inline vcl_ostream& operator<< (vcl_ostream& os,
00233                                 bsta_mixture<dist_> const& m)
00234 {
00235   typedef typename dist_::math_type T;
00236   unsigned n = m.num_components();
00237   for (unsigned c = 0; c<n; ++c){
00238     const dist_& mc = m.distribution(c);
00239     T weight = m.weight(c);
00240     os << "mixture_comp["<< c << "]wgt(" << weight << ")\n" << mc << '\n';
00241   }
00242   return os;
00243 }
00244 
00245 //: for compatibility with vpdl/vpdt
00246 template <class dist>
00247 struct vpdt_is_mixture<bsta_mixture<dist> >
00248 {
00249   static const bool value = true;
00250 };
00251 
00252 #endif // bsta_mixture_h_