module lbfgsd.linesearch;

private import lbfgsd.cost;

/// Acceptance criterion used by `BackTrackLineSearcher.search`.
enum LineSearchType
{
    Armijo,      /// sufficient-decrease test only
    Wolfe,       /// sufficient decrease + one-sided curvature test
    StrongWolfe  /// sufficient decrease + two-sided curvature test
}

/// Tunable parameters of the line search.
struct LineSearchOptions(T)
{
    /// Acceptance criterion.
    LineSearchType type = LineSearchType.Armijo;

    /// When true, every trial point is evaluated for cost and gradient in one
    /// `evaluate` call. When false, trial points are first evaluated for the
    /// cost only, and the gradient is requested in a separate call once the
    /// sufficient-decrease test passes. Applies to every search type.
    bool bulkEvaluate = true;

    /// Maximum number of trial points evaluated per search.
    size_t maxIterations = 20;

    /// Trial step used by the `search` overload without an explicit step.
    T initialStepSize = 1.0;

    /// Curvature constant used by the Wolfe and strong Wolfe tests.
    T wolfe = 0.9;

    /// Sufficient-decrease constant used by every search type.
    T armijo = 1e-4;

    /// Factor applied to the step when it is judged too short
    /// (the Wolfe curvature test fails).
    T increase = 2.1;

    /// Factor applied to the step when it is judged too long (sufficient
    /// decrease or strong Wolfe test fails, or cost/slope is NaN).
    T decrease = 0.5;
}

/// Outcome of one line search.
struct LineSearchResult(T)
{
    bool success;          /// true if the configured criterion was satisfied
    size_t numIterations;  /// number of trial points evaluated
    T stepSize;            /// step of the last evaluated trial point (the accepted one on success)
    T cost;                /// cost at that trial point
}

/// Line searcher that rescales the step by fixed factors until the configured
/// acceptance criterion holds or the iteration budget is exhausted.
class BackTrackLineSearcher(T, size_t NInput)
{
    alias Cost = CostFunction!(T, NInput);
    alias Result = LineSearchResult!T;
    alias Options = LineSearchOptions!T;

public:
    /// Mutable access to the search options.
    ref Options options() @safe @nogc pure nothrow
    {
        return _options;
    }

    /// Sets the cost function evaluated at each trial point.
    void setCostFunction(Cost cost) @safe @nogc pure nothrow
    {
        _cost = cost;
    }

    /// Same as the full overload, starting from `options.initialStepSize`.
    Result search(in T[] x, in T[] g, in T[] d, in T f0, T[] xn, T[] gn)
    {
        return search(x, g, d, f0, xn, gn, _options.initialStepSize);
    }

    /**
     * Searches along direction `d` from `x`, starting with trial step `step`.
     *
     * Params:
     *   x    = current point
     *   g    = gradient at `x`
     *   d    = search direction; the curvature tests assume dot(g, d) < 0
     *          (a descent direction) — this is not checked here
     *   f0   = cost at `x`
     *   xn   = output: last evaluated trial point `x + stepSize * d`
     *   gn   = output: gradient buffer passed to the cost function. When
     *          `bulkEvaluate` is false and the search fails, it may not
     *          correspond to `xn`.
     *   step = first trial step
     *
     * Returns: a result whose `stepSize` and `cost` describe the last
     * evaluated trial point (the accepted one when `success` is true).
     */
    Result search(in T[] x, in T[] g, in T[] d, in T f0, T[] xn, T[] gn, T step)
    {
        import std.math : isNaN;
        import std.numeric : dotProduct;

        Result result = void;
        result.success = false;
        result.numIterations = 0;
        result.stepSize = step;

        // Slope of the cost along d at x.
        immutable ginit = dotProduct(g, d);
        // Right-hand sides of the sufficient-decrease and curvature tests.
        immutable c_armijo = _options.armijo * ginit;
        immutable c_wolfe = _options.wolfe * ginit;
        immutable type = _options.type;
        immutable bulk = _options.bulkEvaluate;
        immutable inc = _options.increase;
        immutable dec = _options.decrease;

        T fx;
        foreach (i; 0 .. _options.maxIterations)
        {
            ++result.numIterations;

            xn[] = x[] + step * d[];
            // NOTE(review): passing null presumably asks the cost function for
            // the cost only — confirm against CostFunction.evaluate.
            fx = bulk
                ? _cost.evaluate(xn, gn)
                : _cost.evaluate(xn, null);
            // Remember the step that produced xn/fx, so the result stays
            // consistent with them even if the budget runs out below.
            result.stepSize = step;

            // Sufficient decrease (Armijo). A NaN cost compares false with
            // everything, so it must be rejected explicitly.
            if (isNaN(fx) || fx > f0 + step * c_armijo)
            {
                step *= dec;
                continue;
            }

            // Every criterion needs the gradient at an accepted point.
            if (!bulk)
                _cost.evaluate(xn, gn);

            if (type == LineSearchType.Armijo)
            {
                result.success = true;
                break;
            }

            // Slope of the cost along d at the trial point.
            const dg = dotProduct(gn, d);

            // A NaN slope would slip through both curvature comparisons below.
            if (isNaN(dg))
            {
                step *= dec;
                continue;
            }

            // Wolfe curvature: slope still too negative => step too short.
            if (dg < c_wolfe)
            {
                step *= inc;
                continue;
            }
            if (type == LineSearchType.Wolfe)
            {
                result.success = true;
                break;
            }

            // Strong Wolfe: for ginit < 0 this bounds |dg| by wolfe * |ginit|;
            // a large positive slope means the step overshot.
            if (dg > -c_wolfe)
            {
                step *= dec;
                continue;
            }
            if (type == LineSearchType.StrongWolfe)
            {
                result.success = true;
                break;
            }
        }

        result.cost = fx;
        return result;
    }

private:
    Cost _cost;
    Options _options;
}

unittest
{
    static struct Func
    {
        T opCall(T)(in T[] x) @safe @nogc pure nothrow
        {
            import lbfgsd.math;
            auto t1 = x[0] - 1;
            auto t2 = x[1] + 10;
            return t1 * t1 + t2 * t2 + exp(x[0] + x[1]);
        }
    }

    Func fn;
    auto cost = new AutoDiffCostFunction!(Func, double, 2)(fn);

    auto searcher = new BackTrackLineSearcher!(double, 2);
    foreach (t; [LineSearchType.Armijo, LineSearchType.Wolfe, LineSearchType.StrongWolfe])
    {
        searcher.options.type = t;
        searcher.options.maxIterations = 5;
        searcher.setCostFunction(cost);

        auto x = new double[2];
        auto g = new double[2];
        auto xn = new double[2];
        auto gn = new double[2];
        auto d = new double[2];

        x[] = 0;
        auto f = cost.evaluate(x, g);
        d[] = -g[];

        auto result = searcher.search(x, g, d, f, xn, gn);

        assert(result.success);
        assert(result.numIterations <= 5);
        assert(result.stepSize > 0);
    }
}

unittest
{
    import lbfgsd.functions;
    RosenBrockFunction fn;
    auto cost = new AutoDiffCostFunction!(RosenBrockFunction, double, 2)(fn);

    auto searcher = new BackTrackLineSearcher!(double, 2);
    foreach (t; [LineSearchType.Armijo, LineSearchType.Wolfe, LineSearchType.StrongWolfe])
    {
        searcher.options.type = t;
        searcher.options.maxIterations = 5;
        searcher.options.bulkEvaluate = false;
        searcher.setCostFunction(cost);

        auto x = new double[2];
        auto g = new double[2];
        auto xn = new double[2];
        auto gn = new double[2];
        auto d = new double[2];

        x[] = 0;
        auto f = cost.evaluate(x, g);
        d[] = -g[];

        auto result = searcher.search(x, g, d, f, xn, gn);

        assert(result.success);
        assert(result.numIterations <= 5);
        assert(result.stepSize > 0);
    }
}