core/vnl/algo/vnl_lbfgs.cxx
Go to the documentation of this file.
00001 // This is core/vnl/algo/vnl_lbfgs.cxx
00002 #ifdef VCL_NEEDS_PRAGMA_INTERFACE
00003 #pragma implementation
00004 #endif
00005 //:
00006 // \file
00007 //
00008 // \author Andrew W. Fitzgibbon, Oxford RRG
00009 // \date   22 Aug 99
00010 //
00011 //-----------------------------------------------------------------------------
00012 
00013 #include "vnl_lbfgs.h"
00014 #include <vcl_cmath.h>
00015 #include <vcl_iostream.h>
00016 #include <vcl_iomanip.h> // for setw (replaces cout.form())
00017 
00018 #include <vnl/algo/vnl_netlib.h> // lbfgs_()
00019 
00020 //: Default constructor.
00021 // memory is set to 5, line_search_accuracy to 0.9.
00022 // Calls init_parameters
00023 vnl_lbfgs::vnl_lbfgs():
00024   f_(0)
00025 {
00026   init_parameters();
00027 }
00028 
00029 //: Constructor. f is the cost function to be minimized.
00030 // Calls init_parameters
00031 vnl_lbfgs::vnl_lbfgs(vnl_cost_function& f):
00032   f_(&f)
00033 {
00034   init_parameters();
00035 }
00036 
00037 //: Called by constructors.
00038 // Memory is set to 5, line_search_accuracy to 0.9, default_step_length to 1.
00039 void vnl_lbfgs::init_parameters()
00040 {
00041   memory = 5;
00042   line_search_accuracy = 0.9;
00043   default_step_length = 1.0;
00044 }
00045 
00046 bool vnl_lbfgs::minimize(vnl_vector<double>& x)
00047 {
00048   // Local variables
00049   // The driver for vnl_lbfgs must always declare LB2 as EXTERNAL
00050 
00051   long n = f_->get_number_of_unknowns();
00052   long m = memory; // The number of basis vectors to remember.
00053 
00054   // Create an instance of the lbfgs global data to pass as an
00055   // argument.  It must persist through all calls in this
00056   // minimization.
00057   v3p_netlib_lbfgs_global_t lbfgs_global;
00058   v3p_netlib_lbfgs_init(&lbfgs_global);
00059 
00060   long iprint[2] = {1, 0};
00061   vnl_vector<double> g(n);
00062 
00063   // Workspace
00064   vnl_vector<double> diag(n);
00065 
00066   vnl_vector<double> w(n * (2*m+1)+2*m);
00067 
00068   if (verbose_)
00069     vcl_cerr << "vnl_lbfgs: n = "<< n <<", memory = "<< m <<", Workspace = "
00070              << w.size() << "[ "<< ( w.size() / 128.0 / 1024.0) <<" MB], ErrorScale = "
00071              << f_->reported_error(1) <<", xnorm = "<< x.magnitude() << vcl_endl;
00072 
00073   bool we_trace = (verbose_ && !trace);
00074 
00075   if (we_trace)
00076     vcl_cerr << "vnl_lbfgs: ";
00077 
00078   double best_f = 0;
00079   vnl_vector<double> best_x;
00080 
00081   bool ok;
00082   this->num_evaluations_ = 0;
00083   this->num_iterations_ = 0;
00084   long iflag = 0;
00085   while (true) {
00086     // We do not wish to provide the diagonal matrices Hk0, and therefore set DIAGCO to FALSE.
00087     v3p_netlib_logical diagco = false;
00088 
00089     // Set these every iter in case user changes them to bail out
00090     double eps = gtol; // Gradient tolerance
00091     double local_xtol = 1e-16;
00092     lbfgs_global.gtol = line_search_accuracy; // set to 0.1 for huge problems or cheap functions
00093     lbfgs_global.stpinit = default_step_length;
00094 
00095     // Call function
00096     double f;
00097     f_->compute(x, &f, &g);
00098     if (this->num_evaluations_ == 0) {
00099       this->start_error_ = f;
00100       best_f = f;
00101     } else if (f < best_f) {
00102       best_x = x;
00103       best_f = f;
00104     }
00105 
00106 #define print_(i,a,b,c,d) vcl_cerr<<vcl_setw(6)<<i<<' '<<vcl_setw(20)<<a<<' '\
00107            <<vcl_setw(20)<<b<<' '<<vcl_setw(20)<<c<<' '<<vcl_setw(20)<<d<<'\n'
00108 
00109     if (check_derivatives_)
00110     {
00111       vcl_cerr << "vnl_lbfgs: f = " << f_->reported_error(f) << ", computing FD gradient\n";
00112       vnl_vector<double> fdg = f_->fdgradf(x);
00113       if (verbose_)
00114       {
00115         int l = n;
00116         int limit = 100;
00117         int limit_tail = 10;
00118         if (l > limit + limit_tail) {
00119           vcl_cerr << " [ Showing only first " <<limit<< " components ]\n";
00120           l = limit;
00121         }
00122         print_("i","x","g","fdg","dg");
00123         print_("-","-","-","---","--");
00124         for (int i = 0; i < l; ++i)
00125           print_(i, x[i], g[i], fdg[i], g[i]-fdg[i]);
00126         if (n > limit) {
00127           vcl_cerr << "   ...\n";
00128           for (int i = n - limit_tail; i < n; ++i)
00129             print_(i, x[i], g[i], fdg[i], g[i]-fdg[i]);
00130         }
00131       }
00132       vcl_cerr << "   ERROR = " << (fdg - g).squared_magnitude() / vcl_sqrt(double(n)) << "\n";
00133     }
00134 
00135     iprint[0] = trace ? 1 : -1; // -1 no o/p, 0 start and end, 1 every iter.
00136     iprint[1] = 0; // 1 prints X and G
00137     v3p_netlib_lbfgs_(
00138       &n, &m, x.data_block(), &f, g.data_block(), &diagco, diag.data_block(),
00139       iprint, &eps, &local_xtol, w.data_block(), &iflag, &lbfgs_global);
00140 
00141     this->report_eval(f);
00142 
00143     if (this->report_iter()) {
00144       failure_code_ = FAILED_USER_REQUEST;
00145       ok = false;
00146       x = best_x;
00147       break;
00148     }
00149 
00150     if (we_trace)
00151       vcl_cerr << iflag << ":" << f_->reported_error(f) << " ";
00152 
00153     if (iflag == 0) {
00154       // Successful return
00155       this->end_error_ = f;
00156       ok = true;
00157       x = best_x;
00158       break;
00159     }
00160 
00161     if (iflag < 0) {
00162       // Netlib routine lbfgs failed
00163       vcl_cerr << "vnl_lbfgs: Error. Netlib routine lbfgs failed.\n";
00164       ok = false;
00165       x = best_x;
00166       break;
00167     }
00168 
00169     if (this->num_evaluations_ > get_max_function_evals()) {
00170       failure_code_ = TOO_MANY_ITERATIONS;
00171       ok = false;
00172       x = best_x;
00173       break;
00174     }
00175 
00176   }
00177   if (we_trace) vcl_cerr << "done\n";
00178 
00179   return ok;
00180 }