module lbfgsd.solver;

private import lbfgsd.cost;
private import lbfgsd.autodiff;
private import lbfgsd.linesearch;

private import std.math: sqrt;
private import std.numeric: dotProduct;

/**
 * Settings read by SimpleSolver.solve.
 */
struct SolverOptions(T)
{
    /// Outer-iteration budget; solve stops with OverMaxIterations once it is reached.
    size_t maxIterations = 200;
    /// Convergence threshold. solve stops when the SQUARED gradient norm is below
    /// max(1, squared parameter norm) * gradientTolerance.
    T gradientTolerance = 1e-20;

    /// When true the initial step handed to each line search is 1 / |search vector|;
    /// when false initialStepSize is used instead.
    bool estimateStepSize = true;
    /// Initial line-search step used when estimateStepSize is false.
    T initialStepSize = 1.0;

    /// Options copied into the line searcher at the start of every solve call.
    LineSearchOptions!T linesearch;
}

/**
 * Outcome of one SimpleSolver.solve call.
 */
struct SolverResult(T)
{
    /// Reason the solver stopped.
    SolverResultStatus status = SolverResultStatus.Unknown;

    /// True when status is Success or AlreadyMinimized.
    bool success() const pure nothrow @nogc @safe @property
    {
        return status == SolverResultStatus.Success || status == SolverResultStatus.AlreadyMinimized;
    }

    /// Cost evaluated at the starting point.
    T firstCost;
    /// Cost at the point written back to the caller's array.
    T lastCost;

    /// One record per line search that was attempted, in order.
    SolverIteration!T[] iterations;
}

/**
 * Record of a single outer iteration.
 */
struct SolverIteration(T)
{
    /// Whether the line search reported success.
    bool success;

    /// Number of iterations the line search reported.
    size_t lineSearchIterations;
    /// Step size the line search reported.
    T stepSize;

    /// Cost the line search reported.
    T cost;
    /// SQUARED Euclidean norm of the parameters after the step
    /// (left at T.init when the line search failed).
    T paramNorm;
    /// SQUARED Euclidean norm of the gradient after the step
    /// (left at T.init when the line search failed).
    T gradientNorm;
}

/**
 * Why SimpleSolver.solve returned.
 */
enum SolverResultStatus
{
    Unknown,            /// default value; solve has not assigned a status
    Success,            /// gradient test passed, or the curvature update degenerated to zero
    AlreadyMinimized,   /// gradient test already passed at the starting point
    LinesearchFailed,   /// the line search reported failure; the previous point is returned
    OverMaxIterations   /// the iteration budget ran out before convergence
}

/**
 * L-BFGS solver.
 *
 * Params:
 *   T      = scalar type used for parameters, gradients and costs
 *   NInput = number of parameters (length of the vector passed to solve)
 *   NLBFGS = number of correction pairs kept; with 0 the quasi-Newton code is
 *            compiled out and the search vector is always the negative gradient
 */
class SimpleSolver(T, size_t NInput, size_t NLBFGS = 6)
{
    alias Cost = CostFunction!(T, NInput);
    alias Options = SolverOptions!T;
    alias Result = SolverResult!T;

public:
    /// Creates a solver with the given options and a back-tracking line searcher.
    this(Options options = Options.init)
    {
        _options = options;
        _searcher = new BackTrackLineSearcher!(T, NInput);
    }

public:
    /// Options by reference, so callers can adjust them in place before solve.
    ref Options options() @safe @nogc pure nothrow
    {
        return _options;
    }

public:
    /// Wraps fn in an AutoDiffCostFunction and uses it as the cost.
    void setAutoDiffCost(TFunc)(TFunc fn)
    {
        _cost = new AutoDiffCostFunction!(TFunc, T, NInput)(fn);
    }
    /// Wraps fn in a NumericDiffCostFunction and uses it as the cost.
    void setNumericDiffCost(TFunc)(TFunc fn)
    {
        _cost = new NumericDiffCostFunction!(TFunc, T, NInput)(fn);
    }
    /// Uses an already constructed cost function.
    void setCostFunction(Cost cost)
    {
        _cost = cost;
    }

public:
    /**
     * Minimizes the configured cost function starting from x.
     *
     * Params:
     *   x = starting point, exactly NInput elements; overwritten with the
     *       final point before returning
     *
     * Returns: status, first/last cost and one record per attempted line search.
     *
     * Throws: Exception when x does not have NInput elements or no cost
     *         function has been set.
     */
    Result solve(T[] x)
    {
        import std.exception : enforce;

        // Fail loudly instead of relying on a slice-copy check that release
        // builds may drop, or on a null dereference.
        enforce(x.length == NInput, "SimpleSolver.solve: x must have exactly NInput elements");
        enforce(_cost !is null, "SimpleSolver.solve: no cost function has been set");

        Result result;

        _searcher.options = _options.linesearch;
        _searcher.setCostFunction(_cost);

        //current
        T[NInput] xc;
        xc[] = x[];
        T[NInput] gc;
        //prev
        T[NInput] xp;
        T[NInput] gp;
        //search vector
        T[NInput] sv;

        //L-BFGS
        static if (NLBFGS > 0)
        {
            static struct LBFGSIterateData
            {
                T alpha;
                T[NInput] y;
                T[NInput] s;
                T iys;
            }

            // ring buffer of the most recent correction pairs
            LBFGSIterateData[NLBFGS] buf;
            auto bufPos = 0;
            foreach (ref d; buf)
            {
                d.alpha = 0;
                d.iys = 0;
            }
        }

        auto fx = _cost.evaluate(xc, gc);
        result.firstCost = fx;

        // squared norms of the parameters and of the gradient
        T xnorm = 0;
        T gnorm = 0;
        foreach (j; 0 .. NInput)
        {
            xnorm += xc[j] * xc[j];
            gnorm += gc[j] * gc[j];
        }

        if (xnorm < 1) xnorm = 1;
        if (gnorm < xnorm * _options.gradientTolerance)
        {
            //already minimized
            result.status = SolverResultStatus.AlreadyMinimized;
            result.lastCost = fx;
            return result;
        }

        //H_0 is identity matrix
        sv[] = -gc[];

        //for linesearch
        T step = _options.estimateStepSize
            ? 1.0 / sqrt(gnorm)
            : _options.initialStepSize;

        auto loop = 1;
        for (;;)
        {
            SolverIteration!T iteration;
            // every exit path of the loop body still records the iteration
            scope(exit) { result.iterations ~= iteration; }

            //store
            xp[] = xc[];
            gp[] = gc[];
            // cost belonging to (xp, gp); needed if the line search fails
            const fxPrev = fx;

            //linesearch
            auto lr = _searcher.search(xp, gp, sv, fx, xc, gc, step);
            fx = lr.cost;
            iteration.lineSearchIterations = lr.numIterations;
            iteration.success = lr.success;
            iteration.cost = lr.cost;
            iteration.stepSize = lr.stepSize;
            if (!lr.success)
            {
                //linesearch failed
                result.status = SolverResultStatus.LinesearchFailed;
                //restore the point AND its cost, so lastCost matches the
                //x that is handed back to the caller
                xc[] = xp[];
                gc[] = gp[];
                fx = fxPrev;
                break;
            }

            //check the gradient
            xnorm = 0;
            gnorm = 0;
            foreach (j; 0 .. NInput)
            {
                xnorm += xc[j] * xc[j];
                gnorm += gc[j] * gc[j];
            }
            iteration.paramNorm = xnorm;
            iteration.gradientNorm = gnorm;

            if (xnorm < 1) xnorm = 1;
            if (gnorm < xnorm * _options.gradientTolerance)
            {
                //convergence
                result.status = SolverResultStatus.Success;
                break;
            }

            if (loop >= _options.maxIterations)
            {
                //iteration budget exhausted
                result.status = SolverResultStatus.OverMaxIterations;
                break;
            }

            // NOTE(review): compared with the textbook two-loop recursion
            // (s = parameter difference, y = gradient difference,
            // alpha = rho * s.q, q -= alpha * y, scale = s.y / y.y) the roles
            // of the two difference vectors look exchanged below: `y` holds
            // the parameter difference and `s` the gradient difference, while
            // the loops use them in the textbook positions. Behaviour is left
            // untouched because the unit tests in this file pin the current
            // iteration counts -- TODO confirm with the author.
            static if (NLBFGS > 0)
            {
                //update
                buf[bufPos].y[] = xc[] - xp[];
                buf[bufPos].s[] = gc[] - gp[];

                const ys = dotProduct(buf[bufPos].y, buf[bufPos].s);
                if (ys == 0)
                {
                    // degenerate pair: stop before dividing by zero
                    result.status = SolverResultStatus.Success;
                    break;
                }

                buf[bufPos].iys = 1.0 / ys;
                const yy = dotProduct(buf[bufPos].y, buf[bufPos].y);
                if (yy == 0)
                {
                    result.status = SolverResultStatus.Success;
                    break;
                }
                // scaling applied between the two loops
                const iyy = ys / yy;
            }

            //compute the search vector
            sv[] = -gc[];
            static if (NLBFGS > 0)
            {
                //L-BFGS
                import std.algorithm : min;
                // number of stored pairs that are valid so far
                immutable bound = min(loop, NLBFGS);
                auto j = bufPos = (bufPos + 1) % NLBFGS;

                // first loop: newest pair to oldest
                foreach (_; 0 .. bound)
                {
                    j = (j + NLBFGS - 1) % NLBFGS;
                    buf[j].alpha = dotProduct(buf[j].s, sv) * buf[j].iys;
                    sv[] -= buf[j].alpha * buf[j].y[];
                }
                sv[] *= iyy;
                // second loop: oldest pair to newest
                foreach (_; 0 .. bound)
                {
                    const beta = dotProduct(buf[j].y, sv) * buf[j].iys;
                    sv[] += (buf[j].alpha - beta) * buf[j].s[];
                    j = (j + 1) % NLBFGS;
                }
            }

            //prepare for the next iteration
            ++loop;
            step = _options.estimateStepSize
                ? 1.0 / sqrt(dotProduct(sv, sv))
                : _options.initialStepSize;
        }
        x[] = xc[];
        result.lastCost = fx;

        return result;
    }

private:
    Cost _cost;
    BackTrackLineSearcher!(T, NInput) _searcher;
    Options _options;
}


unittest
{
    auto solver = new SimpleSolver!(double, 3);
    solver.options.linesearch.type = LineSearchType.StrongWolfe;
    solver.options.linesearch.maxIterations = 50;
    solver.options.estimateStepSize = false;
    solver.options.initialStepSize = 0.5;
    solver.options.maxIterations = 50;

    static struct Func
    {
        T opCall(T)(in T[] x)
        {
            import lbfgsd.math;
            auto t0 = x[0] + x[1] - 1;
            auto t1 = x[1] + x[2] + 5;
            auto t2 = x[2] + x[0] + 3;
            return square(t0) + square(t1) + square(t2);
        }
    }
    Func fn;
    solver.setAutoDiffCost(fn);

    auto x = new double[3];
    x[] = 0.5;
    auto result = solver.solve(x);

    assert(result.success);
    assert(result.iterations.length <= 50);
    assert(result.firstCost > 30);
    assert(result.lastCost < 1e-10);
}

unittest
{
    import lbfgsd.functions;

    auto solver = new SimpleSolver!(double, 3);
    solver.options.linesearch.type = LineSearchType.StrongWolfe;
    solver.options.linesearch.maxIterations = 10;
    solver.options.estimateStepSize = true;
    solver.options.maxIterations = 50;

    RosenBrockFunction fn;
    solver.setNumericDiffCost(fn);

    auto x = new double[3];
    x[0] = -1.2;
    x[1] = 0.4;
    x[2] = -0.1;
    auto result = solver.solve(x);

    assert(!result.success);
    assert(result.iterations.length == 50);
    assert(result.firstCost > 30);
    assert(result.lastCost < 5);
}

// solve rejects a parameter vector whose length is not NInput
unittest
{
    import std.exception : assertThrown;

    auto solver = new SimpleSolver!(double, 3);
    assertThrown(solver.solve(new double[2]));
}