00001 #include "vnl_solve_qp.h"
00002
00003
00004
00005
00006
00007 #include <vnl/algo/vnl_svd.h>
00008 #include <vnl/algo/vnl_cholesky.h>
00009 #include <vcl_vector.h>
00010 #include <vcl_cassert.h>
00011 #include <vcl_iostream.h>
00012
00013
00014 static void vnl_solve_symmetric_le(const vnl_matrix<double>& S,
00015 const vnl_vector<double>& b,
00016 vnl_vector<double>& x)
00017 {
00018 vnl_cholesky chol(S,vnl_cholesky::estimate_condition);
00019 if (chol.rcond()>1e-8) x=chol.solve(b);
00020 else
00021 {
00022 vnl_svd<double> svd(S);
00023 x=svd.solve(b);
00024 }
00025 }
00026
00027
00028
00029
00030
00031 bool vnl_solve_qp_with_equality_constraints(const vnl_matrix<double>& H,
00032 const vnl_vector<double>& g,
00033 const vnl_matrix<double>& A,
00034 const vnl_vector<double>& b,
00035 vnl_vector<double>& x)
00036 {
00037
00038
00039 unsigned nc=A.rows();
00040 assert(H.cols()==H.rows());
00041 assert(g.size()==H.rows());
00042 assert(A.cols()==H.rows());
00043 assert(b.size()==nc);
00044
00045 vnl_matrix<double> H_inv;
00046 vnl_cholesky Hchol(H,vnl_cholesky::estimate_condition);
00047 if (Hchol.rcond()>1e-8) H_inv=Hchol.inverse();
00048 else
00049 {
00050 vnl_svd<double> Hsvd(H);
00051 H_inv=Hsvd.inverse();
00052 }
00053
00054 if (nc==0)
00055 {
00056
00057 x=-1.0*H_inv*g;
00058 return true;
00059 }
00060
00061 vnl_vector<double> b1=(b+A*H_inv*g)*-1.0;
00062
00063
00064 vnl_vector<double> lambda;
00065 vnl_matrix<double> AHA = A*H_inv*A.transpose();
00066 vnl_solve_symmetric_le(AHA,b1,lambda);
00067
00068 x=(H_inv*(g+A.transpose()*lambda))*-1.0;
00069 return true;
00070 }
00071
00072
00073
00074
00075
00076
00077 bool vnl_solve_qp_zero_sum(const vnl_matrix<double>& H,
00078 const vnl_vector<double>& g,
00079 vnl_vector<double>& x)
00080 {
00081
00082
00083 assert(H.cols()==H.rows());
00084 assert(g.size()==H.rows());
00085
00086 vnl_matrix<double> H_inv;
00087 vnl_cholesky Hchol(H,vnl_cholesky::estimate_condition);
00088 if (Hchol.rcond()>1e-8) H_inv=Hchol.inverse();
00089 else
00090 {
00091 vnl_svd<double> Hsvd(H);
00092 H_inv=Hsvd.inverse();
00093 }
00094
00095 double b1=-1.0*(H_inv*g).sum();
00096
00097
00098 double H_inv_sum = vnl_c_vector<double>::sum(H_inv.begin(),H_inv.size());
00099
00100 if (vcl_fabs(H_inv_sum)<1e-8)
00101 {
00102 vcl_cerr<<"Uh-oh. H_inv.sum()="<<H_inv_sum<<vcl_endl
00103 <<"H="<<H<<vcl_endl
00104 <<"H_inv="<<H_inv<<vcl_endl;
00105 }
00106
00107
00108 double lambda = b1/H_inv_sum;
00109
00110 vnl_vector<double> g1(g);
00111 g1+=lambda;
00112
00113 x=(H_inv*g1);
00114 x*=-1.0;
00115
00116 return true;
00117 }
00118
00119
00120 static bool vnl_solve_qp_update_x(vnl_vector<double>& x,
00121 const vnl_vector<double>& x1,
00122 vnl_vector<double>& dx,
00123 vcl_vector<bool>& valid,
00124 unsigned& n_valid)
00125 {
00126 unsigned n=x.size();
00127
00128 int worst_i=-1;
00129 double min_alpha=1.0;
00130 for (unsigned i=0;i<n_valid;++i)
00131 {
00132 if (dx[i]<0.0)
00133 {
00134 double alpha = -1.0*x1[i]/dx[i];
00135 if (alpha<min_alpha)
00136 {
00137 min_alpha=alpha; worst_i=i;
00138 }
00139 }
00140 }
00141
00142
00143 unsigned i1=0;
00144 for (unsigned i=0;i<n;++i)
00145 {
00146 if (valid[i])
00147 {
00148 x[i]+=min_alpha*dx[i1];
00149 if (i1==(unsigned int)worst_i)
00150 {
00151
00152 x[i]=0.0;
00153 valid[i]=false;
00154 n_valid--;
00155 }
00156 ++i1;
00157 }
00158 }
00159
00160 return worst_i<0;
00161 }
00162
00163
00164
00165
00166 bool vnl_solve_qp_non_neg_step(const vnl_matrix<double>& H,
00167 const vnl_vector<double>& g,
00168 const vnl_matrix<double>& A,
00169 const vnl_vector<double>& b,
00170 vnl_vector<double>& x,
00171 vcl_vector<bool>& valid,
00172 unsigned& n_valid)
00173 {
00174
00175
00176
00177
00178 unsigned n=H.rows();
00179 unsigned nc=A.rows();
00180
00181 vnl_matrix<double> H1(n_valid,n_valid);
00182 vnl_matrix<double> A1(nc,n_valid);
00183 unsigned j1=0;
00184 for (unsigned j=0;j<n;++j)
00185 {
00186 if (valid[j])
00187 {
00188
00189
00190 unsigned i1=0;
00191 for (unsigned i=0;i<n;++i)
00192 {
00193 if (valid[i]) { H1(i1,j1)=H(i,j); ++i1; }
00194 }
00195
00196
00197 for (unsigned i=0;i<nc;++i,++i1) A1(i,j1)=A(i,j);
00198
00199 ++j1;
00200 }
00201 }
00202
00203 vnl_vector<double> x1(n_valid);
00204 vnl_vector<double> g1(n_valid);
00205 unsigned i1=0;
00206 for (unsigned i=0;i<n;++i)
00207 {
00208 if (valid[i]) { g1[i1]=g[i]; x1[i1]=x[i]; ++i1; }
00209 }
00210 g1 += H1*x1;
00211
00212 vnl_vector<double> b1(b);
00213 b1-= A1*x1;
00214
00215 vnl_vector<double> dx(n_valid,0.0);
00216
00217 vnl_solve_qp_with_equality_constraints(H1,g1,A1,b1,dx);
00218
00219
00220 return vnl_solve_qp_update_x(x,x1,dx,valid,n_valid);
00221 }
00222
00223
00224
00225
00226 bool vnl_solve_qp_non_neg_sum_one_step(const vnl_matrix<double>& H,
00227 const vnl_vector<double>& g,
00228 vnl_vector<double>& x,
00229 vcl_vector<bool>& valid,
00230 unsigned& n_valid)
00231 {
00232
00233
00234
00235
00236 unsigned n=H.rows();
00237
00238 vnl_matrix<double> H1(n_valid,n_valid);
00239 unsigned j1=0;
00240 for (unsigned j=0;j<n;++j)
00241 {
00242 if (valid[j])
00243 {
00244
00245
00246 unsigned i1=0;
00247 for (unsigned i=0;i<n;++i)
00248 {
00249 if (valid[i]) { H1(i1,j1)=H(i,j); ++i1; }
00250 }
00251 ++j1;
00252 }
00253 }
00254
00255 vnl_vector<double> x1(n_valid);
00256 vnl_vector<double> g1(n_valid);
00257 unsigned i1=0;
00258 for (unsigned i=0;i<n;++i)
00259 {
00260 if (valid[i]) { g1[i1]=g[i]; x1[i1]=x[i]; ++i1; }
00261 }
00262 g1 += H1*x1;
00263
00264 vnl_vector<double> dx(n_valid,0.0);
00265
00266 vnl_solve_qp_zero_sum(H1,g1,dx);
00267
00268
00269 return vnl_solve_qp_update_x(x,x1,dx,valid,n_valid);
00270 }
00271
00272
00273
00274
00275
00276
00277
00278
00279
00280
00281
00282
00283
00284
00285 bool vnl_solve_qp_with_non_neg_constraints(const vnl_matrix<double>& H,
00286 const vnl_vector<double>& g,
00287 const vnl_matrix<double>& A,
00288 const vnl_vector<double>& b,
00289 vnl_vector<double>& x,
00290 double con_tol,
00291 bool verbose)
00292 {
00293
00294 unsigned n=H.rows();
00295
00296 assert(H.cols()==n);
00297 assert(g.size()==n);
00298 assert(A.cols()==n);
00299 assert(b.size()==A.rows());
00300
00301 if (vnl_vector_ssd(A*x,b)>con_tol)
00302 {
00303 if (verbose)
00304 vcl_cerr<<"Supplied x does not satisfy equality constraints\n";
00305 return false;
00306 }
00307 for (unsigned i=0;i<n;++i)
00308 {
00309 if (x[i]<0)
00310 {
00311 if (verbose)
00312 vcl_cerr<<"Element "<<i<<" of x is negative. Must be >=0 on input.\n";
00313 return false;
00314 }
00315 }
00316
00317
00318 vcl_vector<bool> valid(n,true);
00319 unsigned n_valid=n;
00320
00321 while (!vnl_solve_qp_non_neg_step(H,g,A,b,x,valid,n_valid)) {}
00322
00323 if (vnl_vector_ssd(A*x,b)>con_tol)
00324 {
00325 if (verbose)
00326 vcl_cerr<<"Oops: Final x does not satisfy equality constraints\n";
00327 return false;
00328 }
00329 else
00330 return true;
00331 }
00332
00333
00334
00335
00336
00337
00338
00339
00340
00341
00342
00343
00344 bool vnl_solve_qp_non_neg_sum_one(const vnl_matrix<double>& H,
00345 const vnl_vector<double>& g,
00346 vnl_vector<double>& x,
00347 bool verbose)
00348 {
00349
00350 unsigned n=H.rows();
00351 assert(H.cols()==n);
00352 assert(g.size()==n);
00353
00354 if (vcl_fabs(x.sum()-1.0)>1e-8)
00355 {
00356 if (verbose)
00357 vcl_cerr<<"Supplied x does not sum to unity.\n";
00358 return false;
00359 }
00360 for (unsigned i=0;i<n;++i)
00361 {
00362 if (x[i]<0)
00363 {
00364 if (verbose)
00365 vcl_cerr<<"Element "<<i<<" of x is negative. Must be >=0 on input.\n";
00366 return false;
00367 }
00368 }
00369
00370
00371 vcl_vector<bool> valid(n,true);
00372 unsigned n_valid=n;
00373
00374 while (!vnl_solve_qp_non_neg_sum_one_step(H,g,x,valid,n_valid)) {}
00375
00376 if (vcl_fabs(x.sum()-1.0)>1e-8)
00377 {
00378 if (verbose)
00379 vcl_cerr<<"Oops. Final x does not sum to unity.\n";
00380 return false;
00381 }
00382 else
00383 return true;
00384 }