1 module lbfgsd.autodiff;
2 
/**
 * Trait: `true` when `T` implicitly converts to some instantiation
 * `Variable!(U, N)` (any scalar type `U`, any size `N`), `false` otherwise.
 */
template isVariable(T)
{
    static if (is(T : Variable!(Scalar, Dim), Scalar, size_t Dim))
        enum bool isVariable = true;
    else
        enum bool isVariable = false;
}

unittest
{
    // Plain scalars and unrelated types are rejected...
    static assert(!isVariable!string);
    static assert(!isVariable!double);
    static assert(!isVariable!float);

    // ...while every Variable instantiation is accepted.
    static assert(isVariable!(Variable!(double, 2)));
    static assert(isVariable!(Variable!(float, 3)));
}
13 
/**
 * Forward-mode automatic-differentiation scalar.
 *
 * Holds a value `a` together with `N` partial derivatives `d`. Arithmetic
 * with other `Variable`s or with plain `T` scalars propagates the derivatives
 * through the sum, difference, product and quotient rules. Equality and
 * ordering look at the value `a` only; derivatives are ignored.
 *
 * Params:
 *     T = scalar type of the value and its derivatives (assumed numeric)
 *     N = number of derivative components (independent variables)
 */
struct Variable(T, size_t N)
{
public:
    /// Creates a constant: value `val` with every derivative set to zero.
    this(T val) @safe @nogc pure nothrow
    {
        d[] = T(0);
        a = val;
    }

    /**
     * Creates the `n`-th independent variable: value `val`, derivative 1 in
     * slot `n` and 0 in every other slot.
     *
     * `n` must be less than `N` (checked with `assert`).
     */
    this(T val, size_t n) @safe @nogc pure nothrow
    {
        assert(n < N);

        d[] = T(0);
        d[n] = T(1);
        a = val;
    }

public:
    /// Assigns a plain scalar, turning this into a constant (all derivatives zero).
    ref Variable opAssign(const scope T val) @safe @nogc pure nothrow
    {
        d[] = T(0);
        a = val;
        return this;
    }

    /// Unary `+` (returns a copy) and unary `-` (negates value and derivatives).
    Variable opUnary(string op)() @safe @nogc const pure nothrow
    {
        static if (op == "+")
        {
            return this;
        }
        else static if (op == "-")
        {
            Variable t;
            t.d[] = -d[];
            t.a = -a;
            return t;
        }
        else
            static assert(false, "Variable: unsupported unary operator " ~ op);
    }

    /**
     * `this op r` for `op` in `+ - * /`, where both operands carry
     * derivatives.
     */
    Variable opBinary(string op)(const scope Variable r) @safe @nogc const pure nothrow
    {
        Variable t;
        static if (op == "+")
        {
            t.d[] = d[] + r.d[];
            t.a = a + r.a;
        }
        else static if (op == "-")
        {
            t.d[] = d[] - r.d[];
            t.a = a - r.a;
        }
        else static if (op == "*")
        {
            // Product rule: (f g)' = f' g + f g'.
            t.d[] = d[] * r.a + a * r.d[];
            t.a = a * r.a;
        }
        else static if (op == "/")
        {
            // Quotient rule: (f / g)' = (f' - (f / g) g') / g, with u = 1 / g.
            const u = T(1) / r.a;
            t.d[] = (d[] - a * u * r.d[]) * u;
            t.a = a * u;
        }
        else
            static assert(false, "Variable: unsupported binary operator " ~ op);
        return t;
    }

    /// `this op r` for `op` in `+ - * /` with a constant scalar right operand.
    Variable opBinary(string op)(const scope T r) @safe @nogc const pure nothrow
    {
        Variable t;
        static if (op == "+")
        {
            t.d[] = d[];
            t.a = a + r;
        }
        else static if (op == "-")
        {
            t.d[] = d[];
            t.a = a - r;
        }
        else static if (op == "*")
        {
            t.d[] = d[] * r;
            t.a = a * r;
        }
        else static if (op == "/")
        {
            t.d[] = d[] / r;
            t.a = a / r;
        }
        else
            static assert(false, "Variable: unsupported binary operator " ~ op);
        return t;
    }

    /// `l op this` for `op` in `+ - * /` with a constant scalar left operand.
    Variable opBinaryRight(string op)(const scope T l) @safe @nogc const pure nothrow
    {
        Variable t;
        static if (op == "+")
        {
            t.d[] = d[];
            t.a = l + a;
        }
        else static if (op == "-")
        {
            t.d[] = -d[];
            t.a = l - a;
        }
        else static if (op == "*")
        {
            t.d[] = l * d[];
            t.a = l * a;
        }
        else static if (op == "/")
        {
            // d(l / x) = -l / x^2 * dx = -(l / x) / x * dx.
            // (Dividing by the derivatives themselves would be wrong and
            // would produce -inf for every zero derivative slot.)
            t.a = l / a;
            const k = -t.a / a;
            t.d[] = k * d[];
        }
        else
            static assert(false, "Variable: unsupported binary operator " ~ op);
        return t;
    }

    /// Compound assignment `this op= r` for `op` in `+ - * /`.
    ref Variable opOpAssign(string op)(const scope Variable r) @safe @nogc pure nothrow
    {
        static if (op == "+")
            return this = this + r;
        else static if (op == "-")
            return this = this - r;
        else static if (op == "*")
            return this = this * r;
        else static if (op == "/")
            return this = this / r;
        else
            static assert(false, "Variable: unsupported compound operator " ~ op);
    }

    /// Compound assignment `this op= r` with a constant scalar `r`.
    ref Variable opOpAssign(string op)(const scope T r) @safe @nogc pure nothrow
    {
        static if (op == "+")
            return this = this + r;
        else static if (op == "-")
            return this = this - r;
        else static if (op == "*")
            return this = this * r;
        else static if (op == "/")
            return this = this / r;
        else
            static assert(false, "Variable: unsupported compound operator " ~ op);
    }

    /// Value equality with a scalar; derivatives are ignored.
    bool opEquals(const scope T rhs) const @safe @nogc pure nothrow
    {
        return a == rhs;
    }

    /// Value equality with another Variable; derivatives are ignored.
    bool opEquals(const scope Variable rhs) const @safe @nogc pure nothrow
    {
        return a == rhs.a;
    }

    /**
     * Orders by value against a scalar: -1, 0 or 1.
     *
     * Compares directly instead of testing the sign of `a - rhs`. The
     * subtraction yields NaN for equal infinities (so `inf <= inf` was false)
     * and can overflow for integral `T`. An unordered (NaN) comparison still
     * yields 1, as before.
     */
    int opCmp(const scope T rhs) const @safe @nogc pure nothrow
    {
        if (a < rhs)
            return -1;
        if (a == rhs)
            return 0;
        return 1;
    }

    /// Orders by value against another Variable; derivatives are ignored.
    int opCmp(const scope Variable rhs) const @safe @nogc pure nothrow
    {
        return opCmp(rhs.a);
    }

public:
    /// Partial derivatives of `a`, one per independent variable.
    T[N] d;
    /// The value.
    T a;
}

@safe pure nothrow unittest
{
    // Regression: d(l / x) = -l / x^2 * dx, checked at x != 1 so a wrong
    // formula cannot coincide with the right one.
    alias Var = Variable!(double, 2);
    const x = Var(2, 0);
    const y = 4.0 / x;
    assert(y.a == 2);
    assert(y.d[0] == -1);
    assert(y.d[1] == 0);

    // Equal infinities compare as equal, not as "greater".
    const inf = Var(double.infinity);
    assert(inf <= double.infinity);
    assert(!(inf > double.infinity));
    assert(inf >= Var(double.infinity));
}
198 
@safe pure nothrow unittest
{
    import std.algorithm : equal;

    alias Var = Variable!(double, 3);
    static immutable double[] noGradient = [0.0, 0.0, 0.0];

    // Constructing from a scalar yields a constant with a zero gradient.
    auto v = Var(1);
    assert(v.a == 1 && v.d[].equal(noGradient));

    // Assigning a scalar keeps the gradient at zero.
    v = 2;
    assert(v.a == 2 && v.d[].equal(noGradient));
}
211 
@safe pure nothrow unittest
{
    alias Var = Variable!(double, 2);
    auto u = Var(1, 0);
    auto w = Var(1, 1);

    // Seeding: each independent variable has a unit derivative in its own slot.
    assert(u.d.length == 2 && w.d.length == 2);
    assert(u.d == [1.0, 0.0]);
    assert(w.d == [0.0, 1.0]);

    // Negation flips both the value and the gradient.
    w = -u;
    assert(w.a == -1);
    assert(w.d == [-1.0, 0.0]);

    // Assigning a scalar turns the variable into a constant.
    u = 0;
    assert(u.a == 0);
    assert(u.d == [0.0, 0.0]);
}
234 
@safe pure nothrow unittest
{
    alias Var = Variable!(double, 2);
    const lhs = Var(1, 0);
    const rhs = Var(2, 1);

    // Sum and difference: gradients combine component-wise.
    const plus = lhs + rhs;
    assert(plus.a == lhs.a + rhs.a);
    assert(plus.d == [1.0, 1.0]);

    const minus = lhs - rhs;
    assert(minus.a == lhs.a - rhs.a);
    assert(minus.d == [1.0, -1.0]);

    // Product rule: d(xy) = y dx + x dy.
    const times = lhs * rhs;
    assert(times.a == lhs.a * rhs.a);
    assert(times.d[0] == rhs.a && times.d[1] == lhs.a);

    // Quotient rule: d(x/y) = dx/y - x dy/y^2.
    const over = lhs / rhs;
    assert(over.a == lhs.a / rhs.a);
    assert(over.d == [0.5, -0.25]);
}
262 
@safe pure nothrow unittest
{
    alias Var = Variable!(double, 1);
    const v = Var(1, 0);
    const double c = 2.0;

    // A constant right operand: + and - leave the derivative alone,
    // * and / scale it.
    const plus = v + c;
    assert(plus.a == v.a + c && plus.d[0] == 1);

    const minus = v - c;
    assert(minus.a == v.a - c && minus.d[0] == 1);

    const times = v * c;
    assert(times.a == v.a * c && times.d[0] == c);

    const over = v / c;
    assert(over.a == v.a / c && over.d[0] == v.d[0] / c);
}
286 
@safe pure nothrow unittest
{
    alias Var = Variable!(double, 1);
    const v = Var(1, 0);

    // Constant on the left: l + x, l - x, l * x and l / x.
    const plus = 1 + v;
    assert(plus.a == 1 + v.a);
    assert(plus.d[0] == 1);

    const minus = 1 - v;
    assert(minus.a == 1 - v.a);
    assert(minus.d[0] == -1);

    const times = 2 * v;
    assert(times.a == 2 * v.a);
    assert(times.d[0] == 2);

    const over = 2 / v;
    assert(over.a == 2 / v.a);
    assert(over.d[0] == -2);
}
309 
@safe pure nothrow unittest
{
    alias Var = Variable!(double, 2);
    auto acc = Var(1, 0);
    const other = Var(2, 1);

    acc += other;   // values add, gradients add
    assert(acc.a == 3 && acc.d == [1.0, 1.0]);

    acc -= other;   // back to the seeded variable
    assert(acc.a == 1 && acc.d == [1.0, 0.0]);

    acc *= other;   // product rule
    assert(acc.a == 2 && acc.d == [2.0, 1.0]);

    acc /= other;   // quotient rule undoes the product
    assert(acc.a == 1 && acc.d == [1.0, 0.0]);
}
336 
@safe pure nothrow unittest
{
    alias Var = Variable!(double, 2);
    auto acc = Var(1, 0);

    // Shifting by a constant leaves the gradient alone.
    acc += 2;
    assert(acc.a == 3 && acc.d == [1.0, 0.0]);
    acc -= 2;
    assert(acc.a == 1 && acc.d == [1.0, 0.0]);

    // Scaling by a constant scales the gradient as well.
    acc *= 2;
    assert(acc.a == 2 && acc.d == [2.0, 0.0]);
    acc /= 4;
    assert(acc.a == 0.5 && acc.d == [0.5, 0.0]);
}
362 
@safe pure nothrow unittest
{
    alias Var = Variable!(double, 2);
    const one = Var(1, 0);

    // Against plain scalars: only the value takes part.
    assert(one > 0 && one == 1 && one < 2);

    // Against other variables, seeded or constant.
    assert(one < Var(2, 0));
    assert(one > Var(0));
    assert(one >= Var(0));
    assert(one == Var(1));
    assert(one <= Var(2));
    assert(one < Var(2));
}