1 module lbfgsd.linesearch;
2 
3 private import lbfgsd.cost;
4 
/// Acceptance criterion used by `BackTrackLineSearcher.search`.
enum LineSearchType
{
    Armijo = 0,      /// sufficient-decrease test only
    Wolfe = 1,       /// sufficient decrease plus the (one-sided) curvature test
    StrongWolfe = 2, /// sufficient decrease plus the two-sided curvature test
}
11 
/**
 * Tuning parameters for `BackTrackLineSearcher`.
 *
 * Params:
 *   T = numeric type of the step-length and coefficient parameters.
 */
struct LineSearchOptions(T)
{
    /// Acceptance criterion applied to each trial step.
    LineSearchType type = LineSearchType.Armijo;

    /// When `true`, every trial point is evaluated together with its gradient.
    /// When `false`, the cost is evaluated alone (a `null` gradient buffer is
    /// passed) and the gradient is computed only once it is needed. This
    /// applies to every search type, including `Armijo`.
    bool bulkEvaluate = true;

    /// Maximum number of trial points evaluated by one search.
    size_t maxIterations = 20;

    /// Step length tried first by the `search` overload that takes no step argument.
    T initialStepSize = 1.0;

    /// Coefficient scaling the initial slope in the (strong) Wolfe curvature tests.
    T wolfe = 0.9;

    /// Coefficient scaling the initial slope in the sufficient-decrease test.
    T armijo = 0.0001;
}
24 
/// Outcome of one call to `BackTrackLineSearcher.search`.
struct LineSearchResult(T)
{
    /// `true` when the search stopped because a trial step met its acceptance tests.
    bool success;

    /// Number of loop iterations (trial points) the search performed.
    size_t numIterations;

    /// Step length reported by the search; on success, the accepted step.
    T stepSize;

    /// Cost value reported by the search (the cost of the last evaluated trial point).
    T cost;
}
32 
/**
 * Backtracking line search along a fixed search direction.
 *
 * Each trial point is `x + step * d`. The step is multiplied by 0.5 when it
 * is too long (sufficient-decrease failure, or strong-Wolfe slope too
 * positive). It is multiplied by 2.1 when it is too short (Wolfe slope still
 * below `wolfe * g·d`). Which tests apply is chosen by `options.type`.
 *
 * Params:
 *   T      = numeric type of points, gradients and costs.
 *   NInput = number of variables, forwarded to `CostFunction`.
 */
class BackTrackLineSearcher(T, size_t NInput)
{
    alias Cost = CostFunction!(T, NInput);
    alias Result = LineSearchResult!T;
    alias Options = LineSearchOptions!T;

public:
    /// Mutable reference to the options read by subsequent calls to `search`.
    ref Options options() @safe @nogc pure nothrow
    {
        return _options;
    }

public:
    /// Sets the cost function evaluated at each trial point.
    void setCostFunction(Cost cost) @safe @nogc pure nothrow
    {
        _cost = cost;
    }

public:
    /// Same as the 7-argument overload, starting from `options.initialStepSize`.
    Result search(in T[] x, in T[] g, in T[] d, in T f0, T[] xn, T[] gn)
    {
        return search(x, g, d, f0, xn, gn, _options.initialStepSize);
    }

    /**
     * Runs a backtracking line search from `x` along `d`.
     *
     * Iteration rules:
     * - Sufficient decrease: a trial step must satisfy
     *   `f(xn) <= f0 + armijo * step * (g·d)`. A NaN cost fails this test.
     *   On failure the step is halved.
     * - `Armijo`: the first step passing that test is accepted.
     * - `Wolfe` / `StrongWolfe`: the slope `gn·d` must be at least
     *   `wolfe * (g·d)`; otherwise the step is multiplied by 2.1.
     * - `StrongWolfe` only: a slope above `-wolfe * (g·d)` halves the step.
     *
     * If `g·d` is positive or NaN, `d` is not a descent direction. In that
     * case the search fails immediately without evaluating the cost, leaving
     * `xn = x`, `gn = g`, `stepSize = 0` and `cost = f0`.
     *
     * Params:
     *   x    = starting point.
     *   g    = gradient at `x`.
     *   d    = search direction.
     *   f0   = cost at `x`.
     *   xn   = output buffer; receives the last trial point `x + stepSize * d`.
     *   gn   = output buffer for the gradient at `xn`. It is filled on success.
     *          After a failed search it may hold a gradient from an earlier
     *          trial point when `bulkEvaluate` is false.
     *   step = first step length to try.
     *
     * Returns: the search outcome. `stepSize` and `cost` both describe the
     * point written to `xn`. `cost` is NaN (for floating `T`) when
     * `maxIterations` is 0.
     */
    Result search(in T[] x, in T[] g, in T[] d, in T f0, T[] xn, T[] gn, T step)
    {
        import std.numeric : dotProduct;

        enum inc = 2.1; // growth factor when the step is too short
        enum dec = 0.5; // shrink factor when the step is too long

        Result result;
        result.success = false;
        result.numIterations = 0;
        result.stepSize = step;

        immutable ginit = dotProduct(g, d);

        // Written as a negated `<=` so that a NaN slope is rejected as well.
        if (!(ginit <= 0))
        {
            // Not a descent direction: report "no step taken".
            xn[] = x[];
            gn[] = g[];
            result.stepSize = 0;
            result.cost = f0;
            return result;
        }

        immutable armijoSlope = _options.armijo * ginit;
        immutable wolfeSlope = _options.wolfe * ginit;
        immutable type = _options.type;
        immutable bulk = _options.bulkEvaluate;

        T fx;           // cost at the most recent trial point
        T tried = step; // step length that produced the current xn / fx
        foreach (i; 0 .. _options.maxIterations)
        {
            ++result.numIterations;

            tried = step;
            xn[] = x[] + step * d[];
            fx = bulk
                ? _cost.evaluate(xn, gn)
                : _cost.evaluate(xn, null);

            // Sufficient decrease (Armijo). Negated `<=` so a NaN cost is
            // treated as a failure and the step is shrunk.
            if (!(fx <= f0 + step * armijoSlope))
            {
                step *= dec;
                continue;
            }

            if (type == LineSearchType.Armijo)
            {
                // Without bulk evaluation gn has not been filled yet.
                if (!bulk) _cost.evaluate(xn, gn);
                result.success = true;
                break;
            }

            // The curvature tests need the gradient at the trial point.
            if (!bulk) _cost.evaluate(xn, gn);

            const dg = dotProduct(gn, d);

            // Wolfe: the slope is still too negative, so the step is too short.
            if (dg < wolfeSlope)
            {
                step *= inc;
                continue;
            }
            if (type == LineSearchType.Wolfe)
            {
                result.success = true;
                break;
            }

            // Strong Wolfe: the slope is too positive, so the step overshoots.
            if (dg > -wolfeSlope)
            {
                step *= dec;
                continue;
            }
            if (type == LineSearchType.StrongWolfe)
            {
                result.success = true;
                break;
            }
        }

        // Report the step that actually produced xn / fx, so stepSize, cost
        // and xn describe the same point even when the search fails.
        result.stepSize = tried;
        result.cost = fx;
        return result;
    }

private:
    Cost _cost;        // objective evaluated at trial points
    Options _options;  // tuning parameters read by search()
}
unittest
{
    // Objective: (x0 - 1)^2 + (x1 + 10)^2 + exp(x0 + x1), differentiated
    // through AutoDiffCostFunction.
    static struct Func
    {
        T opCall(T)(in T[] x) @safe @nogc pure nothrow
        {
            import lbfgsd.math;
            auto dx = x[0] - 1;
            auto dy = x[1] + 10;
            return dx * dx + dy * dy + exp(x[0] + x[1]);
        }
    }

    Func objective;
    auto cost = new AutoDiffCostFunction!(Func, double, 2)(objective);

    auto searcher = new BackTrackLineSearcher!(double, 2);
    searcher.setCostFunction(cost);
    searcher.options.maxIterations = 5;

    const kinds = [LineSearchType.Armijo, LineSearchType.Wolfe, LineSearchType.StrongWolfe];
    foreach (kind; kinds)
    {
        searcher.options.type = kind;

        double[] x = [0.0, 0.0];
        auto g = new double[2];
        auto xn = new double[2];
        auto gn = new double[2];
        auto d = new double[2];

        // Start at the origin and search along steepest descent.
        const f = cost.evaluate(x, g);
        d[] = -g[];

        const r = searcher.search(x, g, d, f, xn, gn);
        assert(r.success);
        assert(r.numIterations <= 5);
        assert(r.stepSize > 0);
    }
}
unittest
{
    import lbfgsd.functions;

    RosenBrockFunction rosenbrock;
    auto cost = new AutoDiffCostFunction!(RosenBrockFunction, double, 2)(rosenbrock);

    // Same checks as above, but with cost and gradient evaluated separately.
    auto searcher = new BackTrackLineSearcher!(double, 2);
    searcher.setCostFunction(cost);
    searcher.options.maxIterations = 5;
    searcher.options.bulkEvaluate = false;

    const kinds = [LineSearchType.Armijo, LineSearchType.Wolfe, LineSearchType.StrongWolfe];
    foreach (kind; kinds)
    {
        searcher.options.type = kind;

        double[] x = [0.0, 0.0];
        auto g = new double[2];
        auto xn = new double[2];
        auto gn = new double[2];
        auto d = new double[2];

        // Start at the origin and search along steepest descent.
        const f = cost.evaluate(x, g);
        d[] = -g[];

        const r = searcher.search(x, g, d, f, xn, gn);
        assert(r.success);
        assert(r.numIterations <= 5);
        assert(r.stepSize > 0);
    }
}