module lbfgsd.autodiff;

/// True when `T` is (or implicitly converts to) some `Variable!(U, N)`.
enum isVariable(T) = is(T : Variable!(U, N), U, size_t N);

unittest
{
    static assert(isVariable!(Variable!(float, 3)));
    static assert(isVariable!(Variable!(double, 2)));
    static assert(!isVariable!float);
    static assert(!isVariable!double);
    static assert(!isVariable!string);
}

/**
Forward-mode automatic-differentiation value with `N` derivative components.

`a` holds the value and `d[i]` holds the derivative with respect to the
`i`-th independent variable. Arithmetic with other `Variable`s or with plain
scalars of type `T` propagates `d` by the sum, product and quotient rules.
Equality and ordering comparisons look at the value `a` only.
*/
struct Variable(T, size_t N)
{
public:
    /// Constant: value `val`, all derivative components zero.
    this(T val) @safe @nogc pure nothrow
    {
        d[] = T(0);
        a = val;
    }

    /**
    Independent variable: value `val`, derivative 1 in slot `n` and 0
    elsewhere. `n` must be less than `N` (checked by `assert`).
    */
    this(T val, size_t n) @safe @nogc pure nothrow
    {
        assert(n < N);

        d[] = T(0);
        d[n] = T(1);
        a = val;
    }

public:
    /// Assigning a plain scalar makes this a constant: derivatives reset to zero.
    ref Variable opAssign(const scope T val) @safe @nogc pure nothrow
    {
        d[] = T(0);
        a = val;
        return this;
    }

    /// Unary `-` negates value and derivatives; unary `+` returns a copy.
    Variable opUnary(string op)() @safe @nogc const pure nothrow
    {
        static if (op == "-")
        {
            Variable t;
            t.d[] = -d[];
            t.a = -a;
            return t;
        }
        else static if (op == "+")
        {
            return this;
        }
        else
            static assert(false, "Variable does not support unary operator " ~ op);
    }

    /// `Variable op Variable` for `+`, `-`, `*`, `/`.
    Variable opBinary(string op)(const scope Variable r) @safe @nogc const pure nothrow
    {
        Variable t;
        static if (op == "+")
        {
            t.d[] = d[] + r.d[];
            t.a = a + r.a;
        }
        else static if (op == "-")
        {
            t.d[] = d[] - r.d[];
            t.a = a - r.a;
        }
        else static if (op == "*")
        {
            // Product rule: (a*b)' = a'*b + a*b'.
            t.d[] = d[] * r.a + a * r.d[];
            t.a = a * r.a;
        }
        else static if (op == "/")
        {
            // Quotient rule: (a/b)' = (a' - (a/b)*b') / b.
            const u = T(1) / r.a;
            t.d[] = (d[] - a * u * r.d[]) * u;
            t.a = a * u;
        }
        else
            static assert(false, "Variable does not support binary operator " ~ op);
        return t;
    }

    /// `Variable op scalar` for `+`, `-`, `*`, `/`; the scalar is treated as a constant.
    Variable opBinary(string op)(const scope T r) @safe @nogc const pure nothrow
    {
        Variable t;
        static if (op == "+")
        {
            t.d[] = d[];
            t.a = a + r;
        }
        else static if (op == "-")
        {
            t.d[] = d[];
            t.a = a - r;
        }
        else static if (op == "*")
        {
            t.d[] = d[] * r;
            t.a = a * r;
        }
        else static if (op == "/")
        {
            t.d[] = d[] / r;
            t.a = a / r;
        }
        else
            static assert(false, "Variable does not support binary operator " ~ op);
        return t;
    }

    /// `scalar op Variable` for `+`, `-`, `*`, `/`; the scalar is treated as a constant.
    Variable opBinaryRight(string op)(const scope T l) @safe @nogc const pure nothrow
    {
        Variable t;
        static if (op == "+")
        {
            t.d[] = d[];
            t.a = l + a;
        }
        else static if (op == "-")
        {
            t.d[] = -d[];
            t.a = l - a;
        }
        else static if (op == "*")
        {
            t.d[] = l * d[];
            t.a = l * a;
        }
        else static if (op == "/")
        {
            // d(l/a) = -l/a^2 * a' = -(l/a)/a * a'. Dividing the quotient by `a`
            // again avoids forming a*a, which could overflow for large |a|.
            const q = l / a;
            t.d[] = d[] * (-q / a);
            t.a = q;
        }
        else
            static assert(false, "Variable does not support binary operator " ~ op);
        return t;
    }

    /// Compound assignment with another `Variable` (`+=`, `-=`, `*=`, `/=`).
    ref Variable opOpAssign(string op)(const scope Variable r) @safe @nogc pure nothrow
    {
        static if (op == "+")
            return this = this + r;
        else static if (op == "-")
            return this = this - r;
        else static if (op == "*")
            return this = this * r;
        else static if (op == "/")
            return this = this / r;
        else
            static assert(false, "Variable does not support operator " ~ op ~ "=");
    }

    /// Compound assignment with a scalar (`+=`, `-=`, `*=`, `/=`).
    ref Variable opOpAssign(string op)(const scope T r) @safe @nogc pure nothrow
    {
        static if (op == "+")
            return this = this + r;
        else static if (op == "-")
            return this = this - r;
        else static if (op == "*")
            return this = this * r;
        else static if (op == "/")
            return this = this / r;
        else
            static assert(false, "Variable does not support operator " ~ op ~ "=");
    }

    /// Equality compares values only; derivatives are ignored.
    bool opEquals(const scope T rhs) @safe @nogc const pure nothrow
    {
        return a == rhs;
    }

    /// ditto
    bool opEquals(const scope Variable rhs) @safe @nogc const pure nothrow
    {
        return a == rhs.a;
    }

    /**
    Three-way comparison of the value against a scalar: -1, 0 or 1.

    Values are compared directly rather than via `a - rhs`. The subtraction
    yields NaN for equal infinities, wraps around for unsigned `T`, and can
    overflow for signed integral `T`. Unordered (NaN) operands return 0.
    */
    int opCmp(const scope T rhs) @safe @nogc const pure nothrow
    {
        if (a < rhs)
            return -1;
        if (a > rhs)
            return 1;
        return 0;
    }

    /// Three-way comparison of values only; derivatives are ignored.
    int opCmp(const scope Variable rhs) @safe @nogc const pure nothrow
    {
        return opCmp(rhs.a);
    }

public:
    /// Derivative components.
    T[N] d;
    /// Value.
    T a;
}

@safe pure nothrow unittest
{
    import std.algorithm;

    alias Var = Variable!(double, 3);
    auto x = Var(1);
    assert(x.a == 1);
    assert(equal(x.d[], [0.0, 0.0, 0.0]));
    x = 2;
    assert(x.a == 2);
    assert(equal(x.d[], [0.0, 0.0, 0.0]));
}

@safe pure nothrow unittest
{
    alias Var = Variable!(double, 2);
    auto x = Var(1, 0);
    auto y = Var(1, 1);
    assert(x.d.length == 2);
    assert(x.d[0] == 1);
    assert(x.d[1] == 0);
    assert(y.d.length == 2);
    assert(y.d[0] == 0);
    assert(y.d[1] == 1);

    y = -x;
    assert(y.d[0] == -1);
    assert(y.d[1] == 0);
    assert(y.a == -1);

    x = 0;
    assert(x.d[0] == 0);
    assert(x.d[1] == 0);
    assert(x.a == 0);
}

@safe pure nothrow unittest
{
    alias Var = Variable!(double, 2);
    auto x = Var(1, 0);
    auto y = Var(2, 1);
    Var z;

    z = x + y;
    assert(z.d[0] == 1);
    assert(z.d[1] == 1);
    assert(z.a == x.a + y.a);

    z = x - y;
    assert(z.d[0] == 1);
    assert(z.d[1] == -1);
    assert(z.a == x.a - y.a);

    z = x * y;
    assert(z.d[0] == y.a);
    assert(z.d[1] == x.a);
    assert(z.a == x.a * y.a);

    z = x / y;
    assert(z.d[0] == 0.5);
    assert(z.d[1] == -0.25);
    assert(z.a == x.a / y.a);
}

@safe pure nothrow unittest
{
    alias Var = Variable!(double, 1);
    auto x = Var(1, 0);
    auto y = 2.0;
    Var z;

    z = x + y;
    assert(z.d[0] == 1);
    assert(z.a == x.a + y);

    z = x - y;
    assert(z.d[0] == 1);
    assert(z.a == x.a - y);

    z = x * y;
    assert(z.d[0] == y);
    assert(z.a == x.a * y);

    z = x / y;
    assert(z.d[0] == x.d[0] / y);
    assert(z.a == x.a / y);
}

@safe pure nothrow unittest
{
    alias Var = Variable!(double, 1);
    auto x = Var(1, 0);
    Var y;

    y = 1 + x;
    assert(y.d[0] == 1);
    assert(y.a == 1 + x.a);

    y = 1 - x;
    assert(y.d[0] == -1);
    assert(y.a == 1 - x.a);

    y = 2 * x;
    assert(y.d[0] == 2);
    assert(y.a == 2 * x.a);

    y = 2 / x;
    assert(y.d[0] == -2);
    assert(y.a == 2 / x.a);
}

@safe pure nothrow unittest
{
    alias Var = Variable!(double, 2);
    auto x = Var(1, 0);
    auto y = Var(2, 1);

    x += y;
    assert(x.d[0] == 1);
    assert(x.d[1] == 1);
    assert(x.a == 3);

    x -= y;
    assert(x.d[0] == 1);
    assert(x.d[1] == 0);
    assert(x.a == 1);

    x *= y;
    assert(x.d[0] == 2);
    assert(x.d[1] == 1);
    assert(x.a == 2);

    x /= y;
    assert(x.d[0] == 1);
    assert(x.d[1] == 0);
    assert(x.a == 1);
}

@safe pure nothrow unittest
{
    alias Var = Variable!(double, 2);
    auto x = Var(1, 0);

    x += 2;
    assert(x.d[0] == 1);
    assert(x.d[1] == 0);
    assert(x.a == 3);

    x -= 2;
    assert(x.d[0] == 1);
    assert(x.d[1] == 0);
    assert(x.a == 1);

    x *= 2;
    assert(x.d[0] == 2);
    assert(x.d[1] == 0);
    assert(x.a == 2);

    x /= 4;
    assert(x.d[0] == 0.5);
    assert(x.d[1] == 0);
    assert(x.a == 0.5);
}

@safe pure nothrow unittest
{
    alias Var = Variable!(double, 2);
    auto x = Var(1, 0);

    assert(x > 0);
    assert(x == 1);
    assert(x < 2);

    auto y = Var(2, 0);
    assert(x < y);

    assert(x > Var(0));
    assert(x >= Var(0));
    assert(x == Var(1));
    assert(x <= Var(2));
    assert(x < Var(2));
}

// scalar / Variable uses the full quotient rule; comparisons are robust to
// infinities and unsigned value types; unary plus is a copy.
@safe pure nothrow unittest
{
    alias Var = Variable!(double, 2);
    auto x = Var(2, 0);

    auto y = 4 / x;
    assert(y.a == 2);
    assert(y.d[0] == -1); // -4 / 2^2
    assert(y.d[1] == 0);  // zero derivative stays zero (not -inf)

    const big = Var(double.infinity);
    assert(big <= double.infinity);
    assert(big >= Var(double.infinity));
    assert(!(big > double.infinity));

    assert(Variable!(uint, 1)(1) < 2u);
    assert(Variable!(uint, 1)(3) > 2u);

    auto p = +x;
    assert(p.a == 2);
    assert(p.d[0] == 1);
}