contrib/rpl/rrel/rrel_irls.cxx
Go to the documentation of this file.
00001 // This is rpl/rrel/rrel_irls.cxx
00002 #include "rrel_irls.h"
00003 //:
00004 // \file
00005 
00006 #include <rrel/rrel_estimation_problem.h>
00007 #include <rrel/rrel_wls_obj.h>
00008 #include <rrel/rrel_util.h>
00009 
00010 #include <vnl/vnl_math.h>
00011 #include <vnl/vnl_vector.h>
00012 #include <vnl/vnl_matrix.h>
00013 
00014 #include <vcl_iostream.h>
00015 #include <vcl_vector.h>
00016 #include <vcl_cassert.h>
00017 
00018 const double rrel_irls::dflt_convergence_tol_ = 1e-4;
00019 const int rrel_irls::dflt_max_iterations_ = 25;
00020 const int rrel_irls::dflt_iterations_for_scale_ = 1;
00021 
00022 // -------------------------------------------------------------------------
00023 rrel_irls::rrel_irls( int max_iterations )
00024   : max_iterations_(max_iterations), test_converge_(true),
00025     convergence_tol_(dflt_convergence_tol_), est_scale_during_(true),
00026     use_weighted_scale_(true),
00027     iterations_for_scale_est_(dflt_iterations_for_scale_),
00028     scale_lower_bound_( -1.0 ),
00029     trace_level_(0), params_initialized_(false), scale_initialized_(false),
00030     obj_fcn_(1e256), prev_obj_fcn_(1e256),
00031     converged_(false), iteration_(0)
00032 {
00033   assert( max_iterations > 0 );
00034 }
00035 
00036 // -------------------------------------------------------------------------
00037 void
00038 rrel_irls::set_est_scale( int iterations_for_scale_est,
00039                           bool use_weighted_scale )
00040 {
00041   est_scale_during_ = true;
00042   use_weighted_scale_ = use_weighted_scale;
00043   iterations_for_scale_est_ = iterations_for_scale_est;
00044   if ( iterations_for_scale_est_ < 0 )
00045     vcl_cerr << "rrel_irls::est_scale_during WARNING last_scale_est_iter is\n"
00046              << "negative, so scale will not be estimated!\n";
00047 }
00048 
00049 // -------------------------------------------------------------------------
00050 //: Set lower bound of scale estimate
00051 void
00052 rrel_irls::set_scale_lower_bound( double lower_scale )
00053 {
00054   scale_lower_bound_ = lower_scale;
00055 }
00056 
00057 // -------------------------------------------------------------------------
00058 void
00059 rrel_irls::set_no_scale_est()
00060 {
00061   est_scale_during_ = false;
00062 }
00063 
00064 // -------------------------------------------------------------------------
00065 void
00066 rrel_irls::initialize_scale( double scale )
00067 {
00068   scale_ = scale;
00069   scale_initialized_ = true;
00070 }
00071 
00072 // -------------------------------------------------------------------------
00073 double
00074 rrel_irls::scale() const
00075 {
00076   assert( scale_initialized_ );
00077   return scale_;
00078 }
00079 
00080 
00081 // -------------------------------------------------------------------------
00082 void
00083 rrel_irls::set_max_iterations( int max_iterations )
00084 {
00085   max_iterations_ = max_iterations;
00086   assert( max_iterations_ > 0 );
00087 }
00088 
00089 
00090 // -------------------------------------------------------------------------
00091 void
00092 rrel_irls::set_convergence_test( double convergence_tol )
00093 {
00094   test_converge_ = true;
00095   convergence_tol_ = convergence_tol;
00096   assert( convergence_tol_ > 0 );
00097 }
00098 
00099 
00100 // -------------------------------------------------------------------------
00101 void
00102 rrel_irls::set_no_convergence_test( )
00103 {
00104   test_converge_ = false;
00105 }
00106 
00107 
00108 // -------------------------------------------------------------------------
00109 void
00110 rrel_irls::initialize_params( const vnl_vector<double>& init_params )
00111 {
00112   params_ = init_params;
00113   params_initialized_ = true;
00114 }
00115 
00116 
00117 bool
00118 rrel_irls::estimate( const rrel_estimation_problem* problem,
00119                      const rrel_wls_obj* obj )
00120 {
00121   iteration_ = 0;
00122   obj_fcn_ = 1e256;
00123   unsigned int num_for_fit = problem->num_samples_to_instantiate();
00124   bool allow_convergence_test = true;
00125   vcl_vector<double> residuals( problem->num_samples() );
00126   vcl_vector<double> weights( problem->num_samples() );
00127   bool failed = false;
00128 
00129   //  Parameter initialization, if necessary
00130   if ( ! params_initialized_ )
00131   {
00132     if ( ! problem->weighted_least_squares_fit( params_, cofact_ ) )
00133       return false;
00134     allow_convergence_test = false;
00135     params_initialized_ = true;
00136   }
00137 
00138 
00139   //  Scale initialization, if necessary
00140   if ( obj->requires_prior_scale() && problem->scale_type() == rrel_estimation_problem::NONE ) {
00141     vcl_cerr << "irls::estimate: Objective function requires a prior scale, and the problem does not provide one.\n"
00142              << "                Aborting estimation.\n";
00143     return false;
00144   } else {
00145     if ( problem->scale_type() == rrel_estimation_problem::NONE && ! scale_initialized_ ) {
00146       problem->compute_residuals( params_, residuals );
00147       scale_ = rrel_util_median_abs_dev_scale( residuals.begin(), residuals.end(), num_for_fit );
00148       allow_convergence_test = false;
00149       scale_initialized_ = true;
00150     }
00151   }
00152 
00153   if ( trace_level_ >= 1 )
00154     vcl_cout << "Initial estimate: " << params_ << ", scale = " << scale_ <<  vcl_endl;
00155 
00156   assert( params_initialized_ && scale_initialized_ );
00157   if ( scale_ <= 1e-8 ) {
00158     unsigned int dof = problem-> param_dof();
00159     cofact_ = 1e-16 * vnl_matrix<double>(dof, dof, vnl_matrix_identity);
00160     scale_ = 0.0;
00161     converged_ = true;
00162     vcl_cerr << "rrel_irls::estimate: initial scale is zero - cannot estimate\n";
00163     // usually, This means that it already has an exact fitting.
00164     // Thus, no harm if return true
00165     return true;
00166   }
00167 
00168 
00169   //  Basic loop:
00170   //  1. Calculate residuals
00171   //  2. Test for convergence, if desired.
00172   //  3. Calculate weights
00173   //  4. Calculate scale
00174   //  5. Calculate new estimate
00175   //
00176 
00177   converged_ = false;
00178   while ( true ) {
00179     //  Step 1.  Residuals
00180     problem->compute_residuals( params_, residuals );
00181     if ( trace_level_ >= 2 ) trace_residuals( residuals );
00182 
00183     //  Step 2.  Convergence.  The allow_convergence_test parameter
00184     //  prevents use of the convergence test until after the
00185     //  iterations involving scale estimation are finished.
00186     if ( test_converge_ && allow_convergence_test &&
00187          has_converged( residuals, obj, problem, &params_ ) ) {
00188       converged_ = true;
00189       break;
00190     }
00191     ++ iteration_;
00192     if ( iteration_ > max_iterations_ ) break;
00193     if ( trace_level_ >= 1 ) vcl_cout << "\nIteration: " << iteration_ << '\n';
00194 
00195     //  Step 3. Weights
00196     problem->compute_weights( residuals, obj, scale_, weights );
00197     if ( trace_level_ >= 2 ) trace_weights( weights );
00198 
00199     //  Step 4.  Scale.  Note: the residuals are reordered and therefore useless after
00200     //  rrel_util_median_abs_dev_scale.
00201     if ( est_scale_during_ && iteration_ <= iterations_for_scale_est_ ) {
00202       allow_convergence_test = false;
00203       if ( trace_level_ >= 1 ) vcl_cout << "num samples for fit = " << num_for_fit << '\n';
00204       if ( use_weighted_scale_ ) {
00205         assert( residuals.size() == weights.size() );
00206         scale_ = rrel_util_weighted_scale( residuals.begin(), residuals.end(),
00207                                            weights.begin(), num_for_fit, (double*)0 );
00208       }
00209       else {
00210         scale_ = rrel_util_median_abs_dev_scale( residuals.begin(), residuals.end(), num_for_fit );
00211       }
00212       if ( trace_level_ >= 1 ) vcl_cout << "Scale estimated: " << scale_ << vcl_endl;
00213       if ( scale_ <= 1e-8 ) {  //  fit exact enough to yield 0 scale estimate
00214         unsigned int dof = problem-> param_dof();
00215         cofact_ = 1e-16 * vnl_matrix<double>(dof, dof, vnl_matrix_identity);
00216         scale_ = 0.0;
00217         converged_ = true;
00218         vcl_cerr << "rrel_irls::estimate:  scale has gone to 0.\n";
00219         break;
00220       }
00221 
00222       // check lower bound
00223       if ( scale_lower_bound_ > 0 && scale_ < scale_lower_bound_ )
00224         scale_ = scale_lower_bound_;
00225     }
00226     else
00227       allow_convergence_test = true;
00228 
00229     // Step 5.  Weighted least-squares
00230     if ( !problem->weighted_least_squares_fit( params_, cofact_, &weights ) ) {
00231       failed = true;
00232       break;
00233     }
00234     if ( trace_level_ >= 1 ) vcl_cout << "Fit: " << params_ << vcl_endl;
00235   }
00236 
00237   return !failed;
00238 }
00239 
00240 
00241 // -------------------------------------------------------------------------
00242 const vnl_vector<double>&
00243 rrel_irls::params() const
00244 {
00245   assert( params_initialized_ );
00246   return params_;
00247 }
00248 
00249 
00250 // -------------------------------------------------------------------------
00251 const vnl_matrix<double>&
00252 rrel_irls::cofactor() const
00253 {
00254   assert( params_initialized_ );
00255   return cofact_;
00256 }
00257 
00258 
00259 // -------------------------------------------------------------------------
00260 int
00261 rrel_irls::iterations_used() const
00262 {
00263   return iteration_-1;
00264 }
00265 
00266 
00267 // -------------------------------------------------------------------------
00268 bool
00269 rrel_irls::has_converged( const vcl_vector<double>& residuals,
00270                           const rrel_wls_obj* obj,
00271                           const rrel_estimation_problem* problem,
00272                           vnl_vector<double>* params )
00273 {
00274   prev_obj_fcn_ = obj_fcn_;
00275   switch ( problem->scale_type() )
00276   {
00277    case rrel_estimation_problem::NONE:
00278     obj_fcn_ = obj->fcn( residuals.begin(), residuals.end(), scale_, params );
00279     break;
00280    case rrel_estimation_problem::SINGLE:
00281     obj_fcn_ = obj->fcn( residuals.begin(), residuals.end(), problem->prior_scale(), params );
00282     break;
00283    case rrel_estimation_problem::MULTIPLE:
00284     obj_fcn_ = obj->fcn( residuals.begin(), residuals.end(), problem->prior_multiple_scales().begin(), params );
00285     break;
00286    default:
00287     assert(!"invalid scale_type");
00288   }
00289 
00290   if ( trace_level_ >= 1 )
00291     vcl_cout << "  prev obj fcn = " << prev_obj_fcn_
00292              << ",  new obj fcn = " << obj_fcn_ << vcl_endl;
00293 
00294   return vnl_math_abs( obj_fcn_ ) < convergence_tol_  ||
00295     vnl_math_abs(obj_fcn_ - prev_obj_fcn_) < convergence_tol_ * obj_fcn_;
00296 }
00297 
00298 
00299 // -------------------------------------------------------------------------
00300 void
00301 rrel_irls::trace_residuals( const vcl_vector<double>& residuals ) const
00302 {
00303   vcl_cout << "Residuals:\n";
00304   for ( unsigned int i=0; i<residuals.size(); ++i )
00305     vcl_cout << "  " << i << ": " << residuals[i] << '\n';
00306 }
00307 
00308 
00309 // -------------------------------------------------------------------------
00310 void
00311 rrel_irls::trace_weights( const vcl_vector<double>& weights ) const
00312 {
00313   vcl_cout << "Weights:\n";
00314   for ( unsigned int i=0; i<weights.size(); ++i )
00315     vcl_cout << "  " << i << ": " << weights[i] << '\n';
00316 }
00317 
00318 
00319 // -------------------------------------------------------------------------
00320 void
00321 rrel_irls::print_params() const
00322 {
00323   vcl_cout << "  max_iterations_ = " << max_iterations_ << '\n'
00324            << "  test_converge_ = " << test_converge_ << '\n'
00325            << "  convergence_tol_ = " << convergence_tol_ << '\n'
00326            << "  est_scale_during_ = " << est_scale_during_ << '\n'
00327            << "  iterations_for_scale_est_ = " << iterations_for_scale_est_
00328            << vcl_endl;
00329 }