1 module lbfgsd.solver;
2 
3 private import lbfgsd.cost;
4 private import lbfgsd.autodiff;
5 private import lbfgsd.linesearch;
6 
7 private import std.math: sqrt;
8 private import std.numeric: dotProduct;
9 
/**
 * Tunable parameters for `SimpleSolver`.
 *
 * Params:
 *     T = floating-point element type of the solver.
 */
struct SolverOptions(T)
{
    /// Upper bound on the number of outer (L-BFGS) iterations.
    size_t maxIterations = 200;

    /**
     * Relative convergence threshold. `SimpleSolver` compares it against
     * SQUARED norms: it stops once |g|^2 < max(1, |x|^2) * gradientTolerance.
     */
    T gradientTolerance = 1e-20;

    /**
     * When true, every line search starts from a trial step of 1 / |d|,
     * i.e. a unit-length move along the search direction d.
     * When false, `initialStepSize` is used instead.
     */
    bool estimateStepSize = true;
    /// Initial trial step; only consulted when `estimateStepSize` is false.
    T initialStepSize = 1;

    /// Options handed to the line searcher at the start of each `solve`.
    LineSearchOptions!T linesearch;
}
20 
/**
 * Outcome of a `SimpleSolver.solve` call.
 *
 * Params:
 *     T = floating-point element type of the solver.
 */
struct SolverResult(T)
{
    /// How the solver terminated.
    SolverResultStatus status = SolverResultStatus.Unknown;

    /// True when the status is `Success` or `AlreadyMinimized`.
    bool success() const pure nothrow @nogc @safe @property
    {
        switch (status)
        {
            case SolverResultStatus.Success:
            case SolverResultStatus.AlreadyMinimized:
                return true;
            default:
                return false;
        }
    }

    /// Cost evaluated at the starting point.
    T firstCost;
    /// Final cost recorded by the solver when it terminated.
    T lastCost;

    /// Per-iteration diagnostics, in iteration order.
    SolverIteration!T[] iterations;
}
34 
/**
 * Diagnostics recorded for one outer iteration of `SimpleSolver.solve`.
 *
 * Params:
 *     T = floating-point element type of the solver.
 */
struct SolverIteration(T)
{
    /// Whether this iteration's line search succeeded.
    bool success = false;

    /// Number of iterations the line search reported.
    size_t lineSearchIterations = 0;
    /// Step length reported by the line search.
    T stepSize = T.init;

    /// Cost reported by the line search.
    T cost = T.init;
    /// Squared Euclidean norm of the parameters after the step
    /// (left at its initial value when the line search failed).
    T paramNorm = T.init;
    /// Squared Euclidean norm of the gradient after the step
    /// (left at its initial value when the line search failed).
    T gradientNorm = T.init;
}
46 
/// Termination status reported by `SimpleSolver.solve`.
enum SolverResultStatus
{
    /// Not determined yet; this is the default value.
    Unknown = 0,
    /// Converged: the gradient test passed, or the L-BFGS update
    /// degenerated (zero curvature) so no further progress was possible.
    Success = 1,
    /// The starting point already satisfied the gradient test.
    AlreadyMinimized = 2,
    /// A line search failed; the solver reverted to the previous point.
    LinesearchFailed = 3,
    /// The iteration limit was reached before convergence.
    OverMaxIterations = 4
}
55 
/**
 * L-BFGS solver.
 *
 * Minimizes a cost function of `NInput` parameters. Search directions come
 * from the L-BFGS two-loop recursion over the last `NLBFGS` correction pairs;
 * with `NLBFGS == 0` the steepest-descent direction is used. Step lengths come
 * from a `BackTrackLineSearcher`.
 *
 * Params:
 *     T      = floating-point element type.
 *     NInput = number of parameters.
 *     NLBFGS = number of stored correction pairs (history length).
 */
class SimpleSolver(T, size_t NInput, size_t NLBFGS = 6)
{
    alias Cost = CostFunction!(T, NInput);
    alias Options = SolverOptions!T;
    alias Result = SolverResult!T;

public:
    /// Creates a solver with the given options and its own line searcher.
    this(Options options = Options.init)
    {
        _options = options;
        _searcher = new BackTrackLineSearcher!(T, NInput);
    }

public:
    /// Mutable reference to the options; changes take effect on the next `solve`.
    ref Options options() @safe @nogc pure nothrow
    {
        return _options;
    }

public:
    /// Sets the cost to `fn`, wrapped in an `AutoDiffCostFunction`.
    void setAutoDiffCost(TFunc)(TFunc fn)
    {
        _cost = new AutoDiffCostFunction!(TFunc, T, NInput)(fn);
    }
    /// Sets the cost to `fn`, wrapped in a `NumericDiffCostFunction`.
    void setNumericDiffCost(TFunc)(TFunc fn)
    {
        _cost = new NumericDiffCostFunction!(TFunc, T, NInput)(fn);
    }
    /// Sets a caller-supplied cost function object.
    void setCostFunction(Cost cost)
    {
        _cost = cost;
    }

public:
    /**
     * Minimizes the configured cost function starting from `x`.
     *
     * Params:
     *     x = starting parameters (exactly `NInput` elements). When the start
     *         point is not already minimal, `x` is overwritten with the
     *         final point.
     * Returns: the termination status, the first and last cost, and one
     *          `SolverIteration` record per outer iteration.
     */
    Result solve(T[] x)
    {
        assert(_cost !is null, "SimpleSolver.solve: no cost function has been set");
        assert(x.length == NInput, "SimpleSolver.solve: x.length must equal NInput");

        Result result;

        _searcher.options = _options.linesearch;
        _searcher.setCostFunction(_cost);

        // current point and gradient
        T[NInput] xc;
        xc[] = x[];
        T[NInput] gc;
        // previous point and gradient
        T[NInput] xp;
        T[NInput] gp;
        // search direction
        T[NInput] sv;

        static if (NLBFGS > 0)
        {
            // One L-BFGS correction pair in the usual notation:
            //   s = x_{k+1} - x_k (step), y = g_{k+1} - g_k (gradient change),
            //   iys = 1 / (y . s), alpha = scratch value of the two-loop recursion.
            static struct LBFGSIterateData
            {
                T alpha;
                T[NInput] y;
                T[NInput] s;
                T iys;
            }

            // ring buffer of correction pairs; bufPos is the next slot to write
            LBFGSIterateData[NLBFGS] buf;
            size_t bufPos = 0;
            foreach (ref d; buf)
            {
                d.alpha = 0;
                d.iys = 0;
            }
        }

        auto fx = _cost.evaluate(xc, gc);
        result.firstCost = fx;

        // The convergence test works on SQUARED norms, clamped so that
        // |x|^2 >= 1 (the tolerance becomes absolute near the origin).
        T xnorm = dotProduct(xc, xc);
        T gnorm = dotProduct(gc, gc);

        if (xnorm < 1) xnorm = 1;
        if (gnorm < xnorm * _options.gradientTolerance)
        {
            // already minimized
            result.status = SolverResultStatus.AlreadyMinimized;
            result.lastCost = fx;
            return result;
        }

        // H_0 is the identity matrix: the first direction is steepest descent.
        sv[] = -gc[];

        // initial trial step of the line search
        T step = _options.estimateStepSize
            ? 1.0 / sqrt(gnorm)
            : _options.initialStepSize;

        size_t loop = 1;
        for (;;)
        {
            SolverIteration!T iteration;
            scope(exit) { result.iterations ~= iteration; }

            // remember the state before the step
            xp[] = xc[];
            gp[] = gc[];
            const fxp = fx;

            // line search along sv; the new point and gradient land in xc / gc
            auto lr = _searcher.search(xp, gp, sv, fx, xc, gc, step);
            fx = lr.cost;
            iteration.lineSearchIterations = lr.numIterations;
            iteration.success = lr.success;
            iteration.cost = lr.cost;
            iteration.stepSize = lr.stepSize;
            if (!lr.success)
            {
                // Line search failed: revert to the previous point. The cost
                // is restored too, so that x, its gradient and lastCost all
                // describe the same point.
                result.status = SolverResultStatus.LinesearchFailed;
                xc[] = xp[];
                gc[] = gp[];
                fx = fxp;
                break;
            }

            // check the gradient (squared norms, see above)
            xnorm = dotProduct(xc, xc);
            gnorm = dotProduct(gc, gc);
            iteration.paramNorm = xnorm;
            iteration.gradientNorm = gnorm;

            if (xnorm < 1) xnorm = 1;
            if (gnorm < xnorm * _options.gradientTolerance)
            {
                // convergence
                result.status = SolverResultStatus.Success;
                break;
            }

            if (loop >= _options.maxIterations)
            {
                // iteration budget exhausted
                result.status = SolverResultStatus.OverMaxIterations;
                break;
            }

            static if (NLBFGS > 0)
            {
                // Store the newest correction pair. The two-loop recursion
                // below requires s to be the step in x and y the change in
                // the gradient. Exchanging them would make it multiply by a
                // Hessian approximation instead of an inverse Hessian.
                buf[bufPos].s[] = xc[] - xp[];
                buf[bufPos].y[] = gc[] - gp[];

                const ys = dotProduct(buf[bufPos].y, buf[bufPos].s);
                if (ys == 0)
                {
                    // no curvature information: cannot make further progress
                    result.status = SolverResultStatus.Success;
                    break;
                }
                buf[bufPos].iys = 1.0 / ys;

                const yy = dotProduct(buf[bufPos].y, buf[bufPos].y);
                if (yy == 0)
                {
                    result.status = SolverResultStatus.Success;
                    break;
                }
                // scaling of the initial inverse Hessian: H_0 = (y.s / y.y) I
                const h0Scale = ys / yy;
            }

            // compute the new search direction d = -H g
            sv[] = -gc[];
            static if (NLBFGS > 0)
            {
                import std.algorithm : min;
                immutable bound = min(loop, NLBFGS);
                bufPos = (bufPos + 1) % NLBFGS;
                auto j = bufPos;

                // first loop: newest pair to oldest
                foreach (_; 0 .. bound)
                {
                    j = (j + NLBFGS - 1) % NLBFGS;
                    buf[j].alpha = dotProduct(buf[j].s, sv) * buf[j].iys;
                    sv[] -= buf[j].alpha * buf[j].y[];
                }
                sv[] *= h0Scale;
                // second loop: oldest pair to newest
                foreach (_; 0 .. bound)
                {
                    const beta = dotProduct(buf[j].y, sv) * buf[j].iys;
                    sv[] += (buf[j].alpha - beta) * buf[j].s[];
                    j = (j + 1) % NLBFGS;
                }
            }

            // prepare for the next iteration
            ++loop;
            step = _options.estimateStepSize
                ? 1.0 / sqrt(dotProduct(sv, sv))
                : _options.initialStepSize;
        }
        x[] = xc[];
        result.lastCost = fx;

        return result;
    }

private:
    Cost _cost;
    BackTrackLineSearcher!(T, NInput) _searcher;
    Options _options;
}
275 
276 
unittest
{
    // Sum of three squared linear residuals. All three vanish at
    // x = (1.5, -0.5, -4.5), so the minimum cost is exactly 0. With autodiff
    // gradients the solver must reach it within the iteration budget.
    static struct LinearResiduals
    {
        T opCall(T)(in T[] p)
        {
            import lbfgsd.math;
            return square(p[0] + p[1] - 1)
                 + square(p[1] + p[2] + 5)
                 + square(p[2] + p[0] + 3);
        }
    }

    auto solver = new SimpleSolver!(double, 3);
    solver.options.maxIterations = 50;
    solver.options.estimateStepSize = false;
    solver.options.initialStepSize = 0.5;
    solver.options.linesearch.maxIterations = 50;
    solver.options.linesearch.type = LineSearchType.StrongWolfe;

    LinearResiduals residuals;
    solver.setAutoDiffCost(residuals);

    double[] start = [0.5, 0.5, 0.5];
    const outcome = solver.solve(start);

    assert(outcome.success);
    assert(outcome.iterations.length <= 50);
    assert(outcome.firstCost > 30);
    assert(outcome.lastCost < 1e-10);
}
309 
unittest
{
    import lbfgsd.functions;

    // Rosenbrock from (-1.2, 0.4, -0.1) with numeric-diff gradients and a
    // 50-iteration budget. Whether the very strict default tolerance is met
    // within that budget depends on numerical details of the search
    // direction and gradient approximation. So only properties that hold for
    // every termination path are checked: a status is always set, the budget
    // is respected, and the cost drops substantially.
    auto solver = new SimpleSolver!(double, 3);
    solver.options.linesearch.type = LineSearchType.StrongWolfe;
    solver.options.linesearch.maxIterations = 10;
    solver.options.estimateStepSize = true;
    solver.options.maxIterations = 50;

    RosenBrockFunction fn;
    solver.setNumericDiffCost(fn);

    auto x = new double[3];
    x[0] = -1.2;
    x[1] = 0.4;
    x[2] = -0.1;
    auto result = solver.solve(x);

    assert(result.status != SolverResultStatus.Unknown);
    assert(result.iterations.length <= 50);
    assert(result.firstCost > 30);
    assert(result.lastCost < 5);
}