/// Small expression compiler/evaluator.
///
/// Supports numeric literals, variables, function calls, unary minus,
/// `* / + -`, the comparisons `== != < <= > >=` and the `cond ? a : b`
/// operator. Variables and functions are discovered while parsing and must be
/// bound (via `expr["name"] = ...` or `expr.name = ...`) before evaluation.
module expression;

import std.algorithm : map, find;
import std.conv : parse, to;
import std.format : formattedWrite;
import std.math : isInfinity;
import std.meta : AliasSeq;
import std.range : ElementType, iota, isInputRange, repeat;
import std.string : format, join;
import std.traits : arity, isSomeChar, isFloatingPoint, isIntegral;
import std.typecons : Nullable;
import std.uni : isAlpha, isAlphaNum, isNumber, isWhite;
import std.utf : toUTF8;

/// Error raised for syntax errors, unbound names and evaluation failures.
class ExpressionError : Exception
{
    /// Creates an error carrying only a message.
    this(string msg, string file = __FILE__, size_t line = __LINE__)
    {
        super(msg, file, line);
    }

    /// Creates an error whose message is suffixed with the 1-based position
    /// reported by `source.pos`.
    this(S)(ref const S source, string msg, string file = __FILE__, size_t line = __LINE__)
    {
        super("%s (%s)".format(msg, source.pos + 1), file, line);
    }
}

/// A compiled expression producing values of type `V`.
struct Expression(V)
{
    @disable this();

    /// Parses `source` into an evaluation tree.
    ///
    /// Throws: `ExpressionError` on a syntax error or on trailing input that
    /// is not part of the expression.
    this(R)(R source) if (isInputRange!R && isSomeChar!(ElementType!R))
    {
        auto src = Source!R(source);
        m_root = compileTertiary(src, m_context);
        skipWS(src);
        if (!src.empty())
        {
            throw new ExpressionError(src, "syntax error, unexpected '" ~ src.front.to!string ~ "'");
        }
    }

    /// Evaluates the expression.
    ///
    /// Throws: `ExpressionError` if a variable or function is still unbound.
    V opCall()
    {
        validate();
        return m_root();
    }

    /// Writes the node tree, one node per line, children indented below
    /// their parent.
    void toString(W)(ref W writer) const
    {
        void writeNode(const Node!V node, string indent = "")
        {
            writer.formattedWrite("%s%s\n", indent, node);
            foreach (child; node.children)
                writeNode(child, indent ~ " ");
        }

        writeNode(m_root);
    }

    /// Binds a value to the variable `name`.
    ///
    /// Throws: `ExpressionError` if the expression does not use `name`.
    void opIndexAssign(V value, string name)
    {
        auto p = name in m_context.variables;
        if (p is null)
            throw new ExpressionError("undefined variable: " ~ name);
        (*p).set(value);
    }

    /// Binds a callable to the function `name` with the arity of `F`.
    ///
    /// Throws: `ExpressionError` if the expression does not call `name` with
    /// that number of arguments.
    void opIndexAssign(F)(F fn, string name)
    {
        // check function type;
        mixin("alias VS = AliasSeq!(" ~ "V.init".repeat(arity!F).join(",") ~ ");");
        V v;
        static assert(__traits(compiles, v = fn(VS)), "invalid function type:" ~ F.stringof);

        // Functions are keyed by name + arity; keep `name` itself untouched so
        // the error message reports what the caller actually passed.
        const key = name ~ arity!F.to!string;
        auto p = key in m_context.functions;
        if (p is null)
            throw new ExpressionError(format!"undefined function (%s/%s)"(name, arity!F));
        (*p).set(fn);
    }

    /// Forwards `expr.name = value` to `expr["name"] = value`.
    template opDispatch(string name)
    {
        void opDispatch(V)(V value)
        {
            this[name] = value;
        }
    }

    /// Names of all variables referenced by the expression.
    auto variables()
    {
        return m_context.variables.keys;
    }

    /// Keys (name followed by arity) of all functions called by the expression.
    auto functions()
    {
        return m_context.functions.keys;
    }

private:
    Node!V m_root;
    Context!V m_context;
    bool m_validated = false;

    /// Ensures every variable and function is bound; the successful result is
    /// cached so the check runs only until it first passes.
    void validate()
    {
        if (!m_validated)
        {
            foreach (var; m_context.variables.values)
            {
                if (!var.isSet)
                    throw new ExpressionError("uninitialized variable: " ~ var.name);
            }
            foreach (fn; m_context.functions.values)
            {
                if (!fn.isSet)
                    throw new ExpressionError("uninitialized function: " ~ fn.name);
            }
            m_validated = true;
        }
    }
}

/// Convenience wrapper: compiles `source` into an `Expression!V`.
Expression!V compileExpression(V = float, R)(R source)
{
    return Expression!V(source);
}

private:

/// Symbols collected while parsing an expression.
struct Context(V)
{
    Variable!V[string] variables;
    FunctionBinding!V[string] functions;

    /// Returns the node for `identifier`, creating it on first use. All uses
    /// of one variable share a single node.
    auto defineVariable(string identifier)
    {
        auto p = identifier in variables;
        if (p !is null)
            return *p;
        return variables[identifier] = new Variable!V(identifier);
    }

    /// Returns a new call node for `identifier` applied to `args`.
    ///
    /// The binding (keyed by name + arity) is shared between all call sites,
    /// but every call site gets its own node so that its arguments are kept.
    auto defineFunction(string identifier, Node!V[] args...)
    {
        identifier ~= args.length.to!string;
        FunctionBinding!V binding;
        auto p = identifier in functions;
        if (p !is null)
        {
            binding = *p;
        }
        else
        {
            binding = new FunctionBinding!V(identifier, args.length);
            functions[identifier] = binding;
        }
        return new Function!V(binding, args);
    }
}

alias compileExpr = compileTertiary;

/// Parses `cond ? a : b` (lowest precedence). Each operand is a comparison.
Node!V compileTertiary(R, V)(ref Source!R src, ref Context!V ctx)
{
    auto arg = compileComparison(src, ctx);
    skipWS(src);
    while (!src.empty)
    {
        switch (src.front)
        {
        case '?':
            src.popFront();

            auto left = compileComparison(src, ctx);

            skipWS(src);
            // The input may end right after the first branch.
            if (src.empty || src.front != ':')
            {
                throw new ExpressionError(src, "syntax error, incomplete operator '?:' detected");
            }
            src.popFront();

            arg = new Tertiary!(V, "?")(arg, left, compileComparison(src, ctx));

            break;
        default:
            return arg;
        }
        skipWS(src);
    }
    return arg;
}

/// Parses the comparison operators `== != > >= < <=` (left associative).
Node!V compileComparison(R, V)(ref Source!R src, ref Context!V ctx)
{
    auto arg = compileAdd(src, ctx);
    skipWS(src);
    while (!src.empty)
    {
        switch (src.front)
        {
        case '=':
            src.popFront();

            if (!src.empty && src.front == '=')
            {
                src.popFront();
                arg = new Binary!(V, "==")(arg, compileAdd(src, ctx));
            }
            else
            {
                throw new ExpressionError(src, "syntax error, incomplete operator '==' detected");
            }

            break;
        case '!':
            src.popFront();

            if (!src.empty && src.front == '=')
            {
                src.popFront();
                arg = new Binary!(V, "!=")(arg, compileAdd(src, ctx));
            }
            else
            {
                throw new ExpressionError(src, "syntax error, incomplete operator '!=' detected");
            }
            break;
        case '>':
            src.popFront();

            if (!src.empty && src.front == '=')
            {
                src.popFront();
                arg = new Binary!(V, ">=")(arg, compileAdd(src, ctx));
            }
            else
            {
                // On end of input compileAdd reports the missing operand.
                arg = new Binary!(V, ">")(arg, compileAdd(src, ctx));
            }

            break;
        case '<':
            src.popFront();

            if (!src.empty && src.front == '=')
            {
                src.popFront();
                arg = new Binary!(V, "<=")(arg, compileAdd(src, ctx));
            }
            else
            {
                // On end of input compileAdd reports the missing operand.
                arg = new Binary!(V, "<")(arg, compileAdd(src, ctx));
            }

            break;
        default:
            return arg;
        }
        skipWS(src);
    }
    return arg;
}

/// Parses `+` and `-` (left associative).
Node!V compileAdd(R, V)(ref Source!R src, ref Context!V ctx)
{
    auto arg = compileMul(src, ctx);
    skipWS(src);
    while (!src.empty)
    {
        switch (src.front)
        {
        case '+':
            src.popFront();
            arg = new Binary!(V, "+")(arg, compileMul(src, ctx));
            break;
        case '-':
            src.popFront();
            arg = new Binary!(V, "-")(arg, compileMul(src, ctx));
            break;
        default:
            return arg;
        }
        skipWS(src);
    }
    return arg;
}

/// Parses `*` and `/` (left associative).
Node!V compileMul(R, V)(ref Source!R src, ref Context!V ctx)
{
    auto arg = compileValue(src, ctx);
    skipWS(src);
    while (!src.empty)
    {
        switch (src.front)
        {
        case '*':
            src.popFront();
            arg = new Binary!(V, "*")(arg, compileValue(src, ctx));
            break;
        case '/':
            src.popFront();
            arg = new Binary!(V, "/")(arg, compileValue(src, ctx));
            break;
        default:
            return arg;
        }
        skipWS(src);
    }
    return arg;
}

/// Parses a primary value: literal, variable, function call, unary minus or
/// a parenthesised expression.
///
/// Throws: `ExpressionError` if no value can be parsed at the current position.
Node!V compileValue(R, V)(ref Source!R src, ref Context!V ctx)
{
    skipWS(src);
    if (src.empty)
        throw new ExpressionError(src, "unexpected end of expression");
    if (isNumber(src.front))
    {
        // literal
        return new Literal!V(parse!V(src));
    }
    if (isAlpha(src.front))
    {
        // identifier
        dchar[] result;
        while (!src.empty && isAlphaNum(src.front))
        {
            result ~= src.front;
            src.popFront();
        }
        string identifier = result.toUTF8();

        skipWS(src);
        if (!src.empty && src.front == '(')
        {
            // function
            Node!V[] args;
            src.popFront();
            skipWS(src);
            while (!src.empty)
            {
                if (src.front == ')')
                {
                    src.popFront();
                    return ctx.defineFunction(identifier, args);
                }
                if (args.length)
                {
                    if (src.front != ',')
                        throw new ExpressionError(src, "comma expected");
                    src.popFront();
                }
                args ~= compileExpr(src, ctx);
                skipWS(src);
            }
            throw new ExpressionError(src, "unexpected end of expression");
        }
        else
        {
            // variable
            return ctx.defineVariable(identifier);
        }
    }
    if (src.front == '-')
    {
        // unary minus
        src.popFront();
        return new Unary!(V, "-")(compileValue(src, ctx));
    }
    if (src.front == '(')
    {
        src.popFront();
        auto expr = compileExpr(src, ctx);
        skipWS(src);
        // The input may end before the parenthesis is closed.
        if (src.empty || src.front != ')')
            throw new ExpressionError(src, "closing parenthesis expected");
        src.popFront();
        return expr;
    }
    throw new ExpressionError(src, "value expected");
}

/// Advances `src` past any whitespace.
void skipWS(R)(ref Source!R src)
{
    while (!src.empty && isWhite(src.front))
        src.popFront();
}

/// Input range wrapper that counts how many elements have been consumed.
struct Source(R)
{
    import std.array;

    R m_src;
    size_t m_pos = 0;

    bool empty() const
    {
        return m_src.empty;
    }

    dchar front() const
    {
        return m_src.front;
    }

    void popFront()
    {
        m_pos++;
        m_src.popFront();
    }

    /// Number of elements popped so far (0-based position of `front`).
    size_t pos() const
    {
        return m_pos;
    }
}

/// A node of the evaluation tree.
interface Node(V)
{
    alias Node = .Node!V;
    /// Evaluates the node.
    V opCall() const;
    /// Direct child nodes, or `null` for leaves.
    const(Node)[] children() const;
    string toString() const;
}

/// Constant value.
class Literal(V) : Node!V
{
    protected V m_val;
    this(V value)
    {
        m_val = value;
    }

    V opCall() const
    {
        return m_val;
    }

    const(Node)[] children() const
    {
        return null;
    }

    override string toString() const
    {
        return "Literal(" ~ m_val.to!string ~ ")";
    }
}

unittest
{
    Node!float node = new Literal!float(1.2);
    assert(node().isClose(1.2));
    assert(node.children is null);
}

/// Prefix operator `op` applied to one operand.
class Unary(V, string op) : Node!V
{
    protected Node m_arg;
    this(Node a)
    {
        m_arg = a;
    }

    V opCall() const
    {
        return mixin(op ~ `m_arg()`);
    }

    const(Node)[] children() const
    {
        return [m_arg];
    }

    override string toString() const
    {
        return "Unary(" ~ op ~ ")";
    }
}

/// Infix operator `op` applied to two operands.
class Binary(V, string op) : Node!V
{
    protected Node[2] m_args;
    this(Node a, Node b)
    {
        m_args = [a, b];
    }

    /// Evaluates left operand, then right operand, then applies `op`.
    ///
    /// Throws: `ExpressionError("division by zero")` for `/` unless the
    /// `AllowDivisionBy0` version is set: for integral `V` when the divisor
    /// is zero, for floating-point `V` when the result is infinite.
    V opCall() const
    {
        V lhs = m_args[0]();
        V rhs = m_args[1]();
        version (AllowDivisionBy0)
        {
        }
        else
        {
            static if (isIntegral!V && op == "/")
            {
                // Integer division by zero would otherwise trap at the
                // hardware level instead of raising a catchable error.
                if (rhs == 0)
                    throw new ExpressionError("division by zero");
            }
        }
        V result = mixin(`lhs ` ~ op ~ ` rhs`);
        version (AllowDivisionBy0)
        {
        }
        else
        {
            static if (isFloatingPoint!V && op == "/")
            {
                if (result.isInfinity)
                    throw new ExpressionError("division by zero");
            }
        }
        return result;
    }

    const(Node)[] children() const
    {
        return m_args;
    }

    override string toString() const
    {
        return "Binary(" ~ op ~ ")";
    }
}

/// Conditional node: evaluates the condition and then exactly one branch.
class Tertiary(V, string op) : Node!V
{
    protected Node[3] m_args;
    this(Node cond, Node success, Node fail)
    {
        m_args = [cond, success, fail];
    }

    /// Returns the second operand when the condition is greater than zero,
    /// otherwise the third.
    V opCall() const
    {
        V test = m_args[0]();

        if (test > 0)
        {
            return m_args[1]();
        }
        else
        {
            return m_args[2]();
        }
    }

    const(Node)[] children() const
    {
        return m_args;
    }

    override string toString() const
    {
        return "Tertiary(" ~ op ~ ")";
    }
}

/// Named variable whose value is supplied after compilation.
class Variable(V) : Node!V
{
    protected
    {
        string m_name;
        Nullable!V m_value;
    }
    this(string name)
    {
        m_name = name;
    }

    V opCall() const
    {
        return m_value.get();
    }

    string name() const
    {
        return m_name;
    }

    void set(V value)
    {
        m_value = value;
    }

    bool isSet() const
    {
        return !m_value.isNull;
    }

    const(Node)[] children() const
    {
        return null;
    }

    override string toString() const
    {
        return "Variable(" ~ m_name ~ ")";
    }
}

/// Callable bound to a function name + arity; shared by every call site of
/// that function within one expression.
class FunctionBinding(V)
{
    private
    {
        string m_name;
        size_t m_arity;
        V delegate(const(Node!V)[] args) m_fn;
    }

    this(string name, size_t arity)
    {
        m_name = name;
        m_arity = arity;
    }

    /// Binds `fn`; its arity must equal the arity this binding was created for.
    void set(F)(F fn)
    {
        enum fna = .arity!F;
        assert(fna == m_arity, format!"%s wrong arity: %s/%s"(m_name, fna, m_arity));
        // Generates "args[0](),args[1](),..." so each argument node is
        // evaluated at call time and passed positionally to `fn`.
        enum params = iota(0, fna).map!(i => format!"args[%s]()"(i)).join(",");
        m_fn = (const(Node!V)[] args) => mixin("fn(" ~ params ~ ")");
    }

    /// Invokes the bound callable with the evaluated `args`.
    V call(const(Node!V)[] args) const
    {
        return m_fn(args);
    }

    /// Key of the binding: function name followed by its arity.
    string name() const
    {
        return m_name;
    }

    bool isSet() const
    {
        return m_fn !is null;
    }
}

/// One call site of a function: a shared binding plus this call's arguments.
class Function(V) : Node!V
{
    protected
    {
        FunctionBinding!V m_binding;
        Node[] m_args;
    }

    this(FunctionBinding!V binding, Node[] args...)
    {
        m_binding = binding;
        // A typesafe-variadic slice may live on the caller's stack; copy it
        // before storing.
        m_args = args.dup;
    }

    size_t arity() const
    {
        return m_args.length;
    }

    V opCall() const
    {
        return m_binding.call(m_args);
    }

    string name() const
    {
        return m_binding.name;
    }

    bool isSet() const
    {
        return m_binding.isSet;
    }

    const(Node)[] children() const
    {
        return m_args;
    }

    override string toString() const
    {
        return format!"Function(%s)"(arity);
    }
}

version (unittest) import std.math : isClose, sqrt;

unittest
{
    float result = {
        float a = 3.2;
        float b = 9.6;
        auto sqrt = (float x) => .sqrt(x);
        auto dist1 = (float dx) => .sqrt(dx * dx + dx * dx);
        auto dist2 = (float dx, float dy) => .sqrt(dx * dx + dy * dy);
        return dist2(sqrt(a), b) * (a + 5.8) / 3 + b / a - (235.6 + 3 * b) / -2.5 + dist1(a);
    }();

    auto expr(string source)
    {
        auto e = compileExpression(source);
        e["a"] = 3.2;
        e.b = 9.6;
        e["sqrt"] = (float x) => .sqrt(x);
        e.dist = (float dx) => .sqrt(dx * dx + dx * dx);
        e.dist = (float dx, float dy) => .sqrt(dx * dx + dy * dy);
        return e();
    }

    assert(expr("dist(sqrt(a),b)*(a+5.8)/3+b/a-(235.6+3*b)/-2.5+dist(a)").isClose(result));
    assert(expr("dist ( sqrt ( a ) , b ) * ( a + 5.8 ) / 3 + b / a - ( 235.6 + 3 * b) / - 2.5 + dist ( a )")
            .isClose(result));
}

unittest
{

    int result = {
        int a = 3;
        int b = 9;
        auto sqr = (int x) => x * x;
        return ((a - 3) * b + 1) * (a + 5) / 3 + b / a - (235 + 3 * b) / -2 + sqr(a);
    }();

    int expr(string source)
    {
        auto e = compileExpression!int(source);
        e["a"] = 3;
        e.b = 9;
        e["sqr"] = (int x) => x * x;
        return e();
    }

    assert(expr("((a - 3) * b + 1) * (a + 5)/3 + b/a - (235 + 3 * b) / -2 + sqr(a)") == result);
}

unittest
{
    int expr(string source)
    {
        auto e = compileExpression!int(source);
        e["a"] = 3;
        e.b = 9;
        return e();
    }

    assert(expr("a < b") == 1);
    assert(expr("a <= b") == 1);
    assert(expr("a != b") == 1);
    assert(expr("b > a") == 1);
    assert(expr("b >= a") == 1);
    assert(expr("a > b") == 0);
    assert(expr("a >= b") == 0);
    assert(expr("a == b") == 0);
    assert(expr("b <= a") == 0);
    assert(expr("b != a") == 1);
}

unittest
{
    int expr(string source)
    {
        auto e = compileExpression!int(source);
        e["a"] = 3;
        e.b = 9;
        return e();
    }

    assert(expr("a < b ? a : b") == 3);
    assert(expr("a <= b ? b : a") == 9);
    assert(expr("a != b ? 0 : 1") == 0);
    assert(expr("b > a ? 2 : 1") == 2);
}

unittest
{
    // The same function used twice keeps the arguments of each call site.
    auto e = compileExpression!int("sqr(a) + sqr(b)");
    e["a"] = 3;
    e["b"] = 9;
    e["sqr"] = (int x) => x * x;
    assert(e() == 90);
}

unittest
{
    // Truncated input is reported as ExpressionError.
    static bool fails(string source)
    {
        try
        {
            auto e = compileExpression!int(source);
        }
        catch (ExpressionError)
        {
            return true;
        }
        return false;
    }

    assert(fails("(a"));
    assert(fails("a ="));
    assert(fails("a !"));
    assert(fails("a >"));
    assert(fails("a <"));
    assert(fails("a ? b"));

    version (AllowDivisionBy0)
    {
    }
    else
    {
        // Integer division by zero is reported as ExpressionError.
        auto e = compileExpression!int("a / b");
        e["a"] = 1;
        e["b"] = 0;
        bool thrown = false;
        try
        {
            e();
        }
        catch (ExpressionError)
        {
            thrown = true;
        }
        assert(thrown);
    }
}