1 module expression;
2 
3 import std.algorithm : map, find;
4 import std.conv : parse, to;
5 import std.format : formattedWrite;
6 import std.math : isInfinity;
7 import std.meta : AliasSeq;
8 import std.range : ElementType, iota, isInputRange, repeat;
9 import std.string : format, join;
10 import std.traits : arity, isSomeChar, isFloatingPoint;
11 import std.typecons : Nullable;
12 import std.uni : isAlpha, isAlphaNum, isNumber, isWhite;
13 import std.utf : toUTF8;
14 
/// Exception type used for every error reported by this module.
class ExpressionError : Exception
{
    /// Creates an error whose message is `msg`, unchanged.
    this(string msg, string file = __FILE__, size_t line = __LINE__)
    {
        super(msg, file, line);
    }

    /**
     * Creates an error whose message is `msg` followed by `source.pos + 1`
     * in parentheses.
     *
     * Params:
     *   source = any value exposing a `pos` member (the parser passes its input wrapper)
     *   msg    = description of the problem
     */
    this(S)(ref const S source, string msg, string file = __FILE__, size_t line = __LINE__)
    {
        super(msg ~ " (" ~ (source.pos + 1).to!string ~ ")", file, line);
    }
}
27 
/**
 * A compiled expression that evaluates to a value of type `V`.
 *
 * Variables and functions named in the source text are registered in an
 * internal context while compiling. Each of them has to be bound, either by
 * index assignment (`e["x"] = value`) or by member-style assignment
 * (`e.x = value`), before the expression is evaluated with `e()`.
 */
struct Expression(V)
{
    @disable this();

    /**
     * Compiles `source` into an expression tree.
     *
     * Throws: ExpressionError if input is left over once a complete
     *         expression has been parsed.
     */
    this(R)(R source) if (isInputRange!R && isSomeChar!(ElementType!R))
    {
        auto src = Source!R(source);
        m_root = compileTertiary(src, m_context);
        skipWS(src);
        if (!src.empty())
        {
            throw new ExpressionError(src, "syntax error, unexpected '" ~ src.front.to!string ~ "'");
        }
    }

    /**
     * Evaluates the expression.
     *
     * Throws: ExpressionError if a variable or function has not been bound yet.
     */
    V opCall()
    {
        validate();
        return m_root();
    }

    /// Writes the expression tree to `writer`, one node per line, children indented by two spaces.
    void toString(W)(ref W writer) const
    {
        void writeNode(const Node!V node, string indent = "")
        {
            writer.formattedWrite("%s%s\n", indent, node);
            foreach (child; node.children)
                writeNode(child, indent ~ "  ");
        }

        writeNode(m_root);
    }

    /**
     * Binds `value` to the variable called `name`.
     *
     * Throws: ExpressionError if the expression does not reference `name`.
     */
    void opIndexAssign(V value, string name)
    {
        auto p = name in m_context.variables;
        if (p is null)
            throw new ExpressionError("undefined variable: " ~ name);
        (*p).set(value);
    }

    /**
     * Binds the callable `fn` to the function called `name` that is used with
     * as many arguments as `fn` takes.
     *
     * The constraint keeps this overload out of the way when the right-hand
     * side is a plain value, so such assignments always reach the overload above.
     *
     * Throws: ExpressionError if the expression has no such function/arity pair.
     */
    void opIndexAssign(F)(F fn, string name) if (is(typeof(arity!F)))
    {
        // check function type: fn must be callable with arity!F values of type V
        mixin("alias VS = AliasSeq!(" ~ "V.init".repeat(arity!F).join(",") ~ ");");
        V v;
        static assert(__traits(compiles, v = fn(VS)), "invalid function type:" ~ F.stringof);

        // Functions are stored under "<name><arity>"; `name` itself is kept
        // untouched so the error message shows what the caller wrote.
        const key = name ~ arity!F.to!string;
        auto p = key in m_context.functions;
        if (p is null)
            throw new ExpressionError(format!"undefined function (%s/%s)"(name, arity!F));
        (*p).set(fn);
    }

    /// Forwards `e.name = value` to `e["name"] = value`.
    template opDispatch(string name)
    {
        void opDispatch(T)(T value)
        {
            this[name] = value;
        }
    }

    /// Names of all variables referenced by the expression.
    auto variables()
    {
        return m_context.variables.keys;
    }

    /// Keys of all functions referenced by the expression (name followed by argument count).
    auto functions()
    {
        return m_context.functions.keys;
    }

private:
    Node!V m_root;
    Context!V m_context;
    bool m_validated = false;

    /// Checks, on the first evaluation only, that every variable and function is bound.
    void validate()
    {
        if (m_validated)
            return;

        foreach (var; m_context.variables.values)
        {
            if (!var.isSet)
                throw new ExpressionError("uninitialized variable: " ~ var.name);
        }
        foreach (fn; m_context.functions.values)
        {
            if (!fn.isSet)
                throw new ExpressionError("uninitialized function: " ~ fn.name);
        }
        m_validated = true;
    }
}
123 
/**
 * Convenience factory: compiles `source` into an `Expression!V`.
 * `V` defaults to `float`.
 */
Expression!V compileExpression(V = float, R)(R source)
{
    auto compiled = Expression!V(source);
    return compiled;
}
128 
129 private:
130 
/// Registry of the variables and functions encountered while compiling.
struct Context(V)
{
    Variable!V[string] variables;
    Function!V[string] functions;

    /// Returns the variable node registered under `identifier`, creating it on first use.
    auto defineVariable(string identifier)
    {
        if (auto known = identifier in variables)
            return *known;

        auto created = new Variable!V(identifier);
        variables[identifier] = created;
        return created;
    }

    /**
     * Returns the function node registered under `identifier` followed by the
     * number of arguments, creating it from `args` on first use.
     */
    auto defineFunction(string identifier, Node!V[] args...)
    {
        string key = identifier ~ args.length.to!string;
        if (auto known = key in functions)
            return *known;

        auto created = new Function!V(key, args);
        functions[key] = created;
        return created;
    }
}
152 
/// Entry point used for nested expressions (parentheses and function arguments).
alias compileExpr = compileTertiary;

/**
 * Parses a comparison optionally followed by one or more `? x : y` parts.
 *
 * Both branches are parsed with `compileComparison`; repeated `?:` parts wrap
 * the node built so far as their condition.
 *
 * Throws: ExpressionError if `?` is not followed by a branch and `:`,
 *         including when the input ends before the `:`.
 */
Node!V compileTertiary(R, V)(ref Source!R src, ref Context!V ctx)
{
    auto arg = compileComparison(src, ctx);
    skipWS(src);
    while (!src.empty && src.front == '?')
    {
        src.popFront();

        auto left = compileComparison(src, ctx);

        skipWS(src);
        // The input may end right after the first branch; check before peeking.
        if (src.empty || src.front != ':')
        {
            throw new ExpressionError(src, "syntax error, incomplete operator '?:' detected");
        }
        src.popFront();

        arg = new Tertiary!(V, "?")(arg, left, compileComparison(src, ctx));
        skipWS(src);
    }
    return arg;
}
185 
/**
 * Parses an additive expression followed by any number of comparisons
 * (`==`, `!=`, `>`, `>=`, `<`, `<=`), combined left to right.
 *
 * Throws: ExpressionError if `=` or `!` is not followed by `=`, including
 *         when the input ends right after the first character.
 */
Node!V compileComparison(R, V)(ref Source!R src, ref Context!V ctx)
{
    auto arg = compileAdd(src, ctx);
    skipWS(src);
    while (!src.empty)
    {
        switch (src.front)
        {
        case '=':
            src.popFront();

            // Check for end of input before peeking at the second character.
            if (src.empty || src.front != '=')
            {
                throw new ExpressionError(src, "syntax error, incomplete operator '==' detected");
            }
            src.popFront();
            arg = new Binary!(V, "==")(arg, compileAdd(src, ctx));
            break;
        case '!':
            src.popFront();

            if (src.empty || src.front != '=')
            {
                throw new ExpressionError(src, "syntax error, incomplete operator '!=' detected");
            }
            src.popFront();
            arg = new Binary!(V, "!=")(arg, compileAdd(src, ctx));
            break;
        case '>':
            src.popFront();

            if (!src.empty && src.front == '=')
            {
                src.popFront();
                arg = new Binary!(V, ">=")(arg, compileAdd(src, ctx));
            }
            else
            {
                // At end of input the operand parser reports the missing value.
                arg = new Binary!(V, ">")(arg, compileAdd(src, ctx));
            }
            break;
        case '<':
            src.popFront();

            if (!src.empty && src.front == '=')
            {
                src.popFront();
                arg = new Binary!(V, "<=")(arg, compileAdd(src, ctx));
            }
            else
            {
                arg = new Binary!(V, "<")(arg, compileAdd(src, ctx));
            }
            break;
        default:
            return arg;
        }
        skipWS(src);
    }
    return arg;
}
256 
/// Parses a chain of `+` / `-` operations over multiplicative terms, combined left to right.
Node!V compileAdd(R, V)(ref Source!R src, ref Context!V ctx)
{
    Node!V sum = compileMul(src, ctx);
    for (skipWS(src); !src.empty; skipWS(src))
    {
        const sign = src.front;
        if (sign == '+')
        {
            src.popFront();
            sum = new Binary!(V, "+")(sum, compileMul(src, ctx));
        }
        else if (sign == '-')
        {
            src.popFront();
            sum = new Binary!(V, "-")(sum, compileMul(src, ctx));
        }
        else
        {
            // Anything else ends the additive chain.
            break;
        }
    }
    return sum;
}
280 
/// Parses a chain of `*` / `/` operations over values, combined left to right.
Node!V compileMul(R, V)(ref Source!R src, ref Context!V ctx)
{
    Node!V product = compileValue(src, ctx);
    for (skipWS(src); !src.empty; skipWS(src))
    {
        const sign = src.front;
        if (sign == '*')
        {
            src.popFront();
            product = new Binary!(V, "*")(product, compileValue(src, ctx));
        }
        else if (sign == '/')
        {
            src.popFront();
            product = new Binary!(V, "/")(product, compileValue(src, ctx));
        }
        else
        {
            // Anything else ends the multiplicative chain.
            break;
        }
    }
    return product;
}
304 
/**
 * Parses a single operand: a literal, a variable, a function call, a negated
 * operand, or a parenthesised expression.
 *
 * Variables and functions are registered in `ctx` as they are encountered.
 *
 * Throws: ExpressionError if the input ends early, a comma or closing
 *         parenthesis is missing, or no operand starts at the current position.
 */
Node!V compileValue(R, V)(ref Source!R src, ref Context!V ctx)
{
    skipWS(src);
    if (src.empty)
        throw new ExpressionError(src, "unexpected end of expression");
    if (isNumber(src.front))
    {
        // literal
        // NOTE(review): a failed conversion inside std.conv.parse surfaces as
        // its own exception type, not as ExpressionError.
        return new Literal!V(parse!V(src));
    }
    if (isAlpha(src.front))
    {
        // identifier
        dchar[] result;
        while (!src.empty && isAlphaNum(src.front))
        {
            result ~= src.front;
            src.popFront();
        }
        string identifier = result.toUTF8();

        skipWS(src);
        if (!src.empty && src.front == '(')
        {
            // function
            Node!V[] args;
            src.popFront();
            skipWS(src);
            while (!src.empty)
            {
                if (src.front == ')')
                {
                    src.popFront();
                    return ctx.defineFunction(identifier, args);
                }
                if (args.length)
                {
                    // every argument after the first must be preceded by a comma
                    if (src.front != ',')
                        throw new ExpressionError(src, "comma expected");
                    src.popFront();
                }
                args ~= compileExpr(src, ctx);
                skipWS(src);
            }
            throw new ExpressionError(src, "unexpected end of expression");
        }
        else
        {
            // variable
            return ctx.defineVariable(identifier);
        }
    }
    if (src.front == '-')
    {
        // unary minus
        src.popFront();
        return new Unary!(V, "-")(compileValue(src, ctx));
    }
    if (src.front == '(')
    {
        src.popFront();
        auto expr = compileExpr(src, ctx);
        skipWS(src);
        // The input may end before the parenthesis is closed; check before peeking.
        if (src.empty || src.front != ')')
            throw new ExpressionError(src, "closing parenthesis expected");
        src.popFront();
        return expr;
    }
    throw new ExpressionError(src, "value expected");
}
375 
/// Advances `src` past any leading whitespace characters.
void skipWS(R)(ref Source!R src)
{
    for (; !src.empty; src.popFront())
    {
        if (!isWhite(src.front))
            return;
    }
}
381 
/**
 * Input range wrapper around `R` that counts how many elements have been
 * consumed through `popFront`.
 */
struct Source(R)
{
    // Brings the range primitives for built-in arrays into scope.
    import std.array;

    R m_src;
    int m_pos = 0;

    /// Number of elements consumed so far.
    size_t pos() const
    {
        return m_pos;
    }

    /// True when the wrapped range has no more elements.
    bool empty() const
    {
        return m_src.empty;
    }

    /// Current element of the wrapped range.
    dchar front() const
    {
        return m_src.front;
    }

    /// Counts the current element as consumed and drops it from the wrapped range.
    void popFront()
    {
        ++m_pos;
        m_src.popFront();
    }
}
410 
/// Common interface of all nodes of a compiled expression tree.
interface Node(V)
{
    /// Lets implementations refer to the instantiated interface simply as `Node`.
    alias Node = .Node!V;

    /// Child nodes of this node.
    const(Node)[] children() const;

    /// Evaluates the node.
    V opCall() const;

    /// Short textual description of the node.
    string toString() const;
}
418 
/// Leaf node holding a constant value.
class Literal(V) : Node!V
{
    protected V m_val;

    /// Stores `value` as the constant this node evaluates to.
    this(V value)
    {
        m_val = value;
    }

    /// Returns the stored constant.
    V opCall() const
    {
        return m_val;
    }

    /// A literal has no children; the result is `null`.
    const(Node)[] children() const
    {
        return null;
    }

    /// Renders as `Literal(<value>)`.
    override string toString() const
    {
        string shown = m_val.to!string;
        return "Literal(" ~ shown ~ ")";
    }
}
442 
// A literal evaluates to its stored value and reports no children.
unittest
{
    auto literal = new Literal!float(1.2);
    Node!float node = literal;
    assert(node().isClose(1.2));
    assert(node.children is null);
}
449 
/// Node applying the prefix operator `op` to a single operand.
class Unary(V, string op) : Node!V
{
    protected Node m_arg;

    /// Wraps operand `a`.
    this(Node a)
    {
        m_arg = a;
    }

    /// Evaluates the operand and applies `op` to the result.
    V opCall() const
    {
        return mixin(op ~ "m_arg()");
    }

    /// Returns a one-element array holding the operand.
    const(Node)[] children() const
    {
        const(Node)[] kids;
        kids ~= m_arg;
        return kids;
    }

    /// Renders as `Unary(<op>)`.
    override string toString() const
    {
        enum label = "Unary(" ~ op ~ ")";
        return label;
    }
}
473 
/// Node applying the infix operator `op` to two operands.
class Binary(V, string op) : Node!V
{
    protected Node[2] m_args;

    /// Stores `a` as the left and `b` as the right operand.
    this(Node a, Node b)
    {
        m_args = [a, b];
    }

    /**
     * Evaluates the left operand, then the right one, and applies `op`.
     *
     * For `/`:
     * $(UL
     *   $(LI integral `V`: a zero divisor always throws, because the division
     *        itself cannot produce a result;)
     *   $(LI floating-point `V`: a zero divisor or an infinite result throws,
     *        unless the module is built with `version = AllowDivisionBy0`.)
     * )
     *
     * Throws: ExpressionError("division by zero") in the cases above.
     */
    V opCall() const
    {
        import std.traits : isIntegral;

        static if (op == "/" && (isIntegral!V || isFloatingPoint!V))
        {
            // Left before right, the same order as the mixed-in expression below.
            V lhs = m_args[0]();
            V rhs = m_args[1]();

            static if (isIntegral!V)
            {
                if (rhs == 0)
                    throw new ExpressionError("division by zero");
                return cast(V)(lhs / rhs);
            }
            else
            {
                V result = lhs / rhs;
                version (AllowDivisionBy0)
                {
                }
                else
                {
                    // `rhs == 0` also covers 0/0, whose result is NaN rather
                    // than infinity; the infinity test is kept so that every
                    // input that used to throw still does.
                    if (rhs == 0 || result.isInfinity)
                        throw new ExpressionError("division by zero");
                }
                return result;
            }
        }
        else
        {
            V result = mixin(`m_args[0]() ` ~ op ~ ` m_args[1]()`);
            return result;
        }
    }

    /// Returns both operands, left first.
    const(Node)[] children() const
    {
        return m_args;
    }

    /// Renders as `Binary(<op>)`.
    override string toString() const
    {
        return "Binary(" ~ op ~ ")";
    }
}
509 
/// Conditional node: evaluates one of two branches depending on a condition.
class Tertiary(V, string op) : Node!V
{
    protected Node[3] m_args;

    /// Stores the condition followed by the two branches.
    this(Node cond, Node success, Node fail)
    {
        m_args = [cond, success, fail];
    }

    /**
     * Evaluates the condition; a result greater than zero selects the
     * `success` branch, anything else the `fail` branch. Only the selected
     * branch is evaluated.
     */
    V opCall() const
    {
        // NOTE(review): a negative condition value selects the `fail` branch;
        // confirm `> 0` is intended rather than `!= 0`.
        return (m_args[0]() > 0) ? m_args[1]() : m_args[2]();
    }

    /// Returns condition, success branch and fail branch, in that order.
    const(Node)[] children() const
    {
        return m_args;
    }

    /// Renders as `Tertiary(<op>)`.
    override string toString() const
    {
        enum label = "Tertiary(" ~ op ~ ")";
        return label;
    }
}
542 
/// Leaf node for a named variable whose value is supplied after compilation.
class Variable(V) : Node!V
{
    protected
    {
        string m_name;
        Nullable!V m_value;
    }

    /// Creates an unset variable called `name`.
    this(string name)
    {
        m_name = name;
    }

    /// The variable's name.
    string name() const
    {
        return m_name;
    }

    /// True once `set` has stored a value.
    bool isSet() const
    {
        const unset = m_value.isNull;
        return !unset;
    }

    /// Stores `value` as the variable's current value.
    void set(V value)
    {
        m_value = value;
    }

    /// Returns the stored value via `Nullable.get`.
    V opCall() const
    {
        return m_value.get();
    }

    /// A variable has no children; the result is `null`.
    const(Node)[] children() const
    {
        return null;
    }

    /// Renders as `Variable(<name>)`.
    override string toString() const
    {
        return format("Variable(%s)", m_name);
    }
}
585 
/// Node for a call to a named function whose implementation is supplied after compilation.
class Function(V) : Node!V
{
    protected
    {
        string m_name;
        V delegate() m_fn;
        Node[] m_args;
    }

    /**
     * Creates an unbound function node.
     *
     * Params:
     *   name = key the function is registered under
     *   args = argument nodes, evaluated each time the function is called
     */
    this(string name, Node[] args...)
    {
        m_name = name;
        // The slice of a typesafe variadic parameter may refer to memory owned
        // by the caller's frame, so keep a private copy instead of the slice.
        m_args = args.dup;
    }

    /// Number of argument nodes.
    size_t arity() const
    {
        return m_args.length;
    }

    /**
     * Binds the callable `fn`; on evaluation it receives the values of the
     * argument nodes in order.
     *
     * Throws: ExpressionError if `fn` does not take exactly `arity` parameters.
     */
    void set(F)(F fn)
    {
        enum fna = .arity!F;
        // Checked at run time in every build mode: the generated call below
        // indexes m_args[0 .. fna].
        if (fna != m_args.length)
            throw new ExpressionError(format!"%s wrong arity: %s/%s"(m_name, fna, m_args.length));
        // Builds "m_args[0](),m_args[1](),..." at compile time.
        enum params = iota(0, fna).map!(i => format!"m_args[%s]()"(i)).join(",");
        m_fn = () => mixin("fn(" ~ params ~ ")");
    }

    /// Evaluates the arguments and calls the bound callable.
    V opCall() const
    {
        return m_fn();
    }

    /// Key the function was created with.
    string name() const
    {
        return m_name;
    }

    /// True once `set` has bound a callable.
    bool isSet() const
    {
        return m_fn !is null;
    }

    /// Returns the argument nodes.
    const(Node)[] children() const
    {
        return m_args;
    }

    /// Renders as `Function(<arity>)`.
    override string toString() const
    {
        return format!"Function(%s)"(arity);
    }
}
639 
640 version (unittest) import std.math : isClose, sqrt;
641 
// Floating-point expression with variables and overloaded functions,
// with and without whitespace between the tokens.
unittest
{
    // Reference value computed directly in D.
    const float expected = {
        float a = 3.2;
        float b = 9.6;
        auto sqrt = (float x) => .sqrt(x);
        auto dist1 = (float dx) => .sqrt(dx * dx + dx * dx);
        auto dist2 = (float dx, float dy) => .sqrt(dx * dx + dy * dy);
        return dist2(sqrt(a), b) * (a + 5.8) / 3 + b / a - (235.6 + 3 * b) / -2.5 + dist1(a);
    }();

    // Compiles `source`, binds every variable and function, and evaluates it.
    auto evaluate(string source)
    {
        auto compiled = compileExpression(source);
        compiled["a"] = 3.2;
        compiled.b = 9.6;
        compiled["sqrt"] = (float x) => .sqrt(x);
        compiled.dist = (float dx) => .sqrt(dx * dx + dx * dx);
        compiled.dist = (float dx, float dy) => .sqrt(dx * dx + dy * dy);
        return compiled();
    }

    foreach (source; [
            "dist(sqrt(a),b)*(a+5.8)/3+b/a-(235.6+3*b)/-2.5+dist(a)",
            "dist ( sqrt ( a ) , b ) * ( a + 5.8 ) / 3 + b / a - ( 235.6 + 3 * b) / - 2.5 + dist ( a )"
        ])
    {
        assert(evaluate(source).isClose(expected));
    }
}
668 
// Integer expression with variables and a one-argument function.
unittest
{
    // Reference value computed directly in D.
    const int expected = {
        int a = 3;
        int b = 9;
        auto sqr = (int x) => x * x;
        return ((a - 3) * b + 1) * (a + 5) / 3 + b / a - (235 + 3 * b) / -2 + sqr(a);
    }();

    // Compiles `source` for `int`, binds its names, and evaluates it.
    int evaluate(string source)
    {
        auto compiled = compileExpression!int(source);
        compiled["a"] = 3;
        compiled.b = 9;
        compiled["sqr"] = (int x) => x * x;
        return compiled();
    }

    const actual = evaluate("((a - 3) * b + 1) * (a + 5)/3 + b/a - (235 + 3 * b) / -2 + sqr(a)");
    assert(actual == expected);
}
690 
// Comparison operators evaluate to 1 when they hold and 0 when they do not.
unittest
{
    // Compiles `source` for `int` with a = 3 and b = 9, then evaluates it.
    int evaluate(string source)
    {
        auto compiled = compileExpression!int(source);
        compiled["a"] = 3;
        compiled.b = 9;
        return compiled();
    }

    foreach (source; ["a < b", "a <= b", "a != b", "b > a", "b >= a", "b != a"])
        assert(evaluate(source) == 1);

    foreach (source; ["a > b", "a >= b", "a == b", "b <= a"])
        assert(evaluate(source) == 0);
}
712 
// The conditional operator picks the branch selected by a comparison.
unittest
{
    // Compiles `source` for `int` with a = 3 and b = 9, then evaluates it.
    int evaluate(string source)
    {
        auto compiled = compileExpression!int(source);
        compiled["a"] = 3;
        compiled.b = 9;
        return compiled();
    }

    string[] sources = ["a < b ? a : b", "a <= b ? b : a", "a != b ? 0 : 1", "b > a ? 2 : 1"];
    int[] expected = [3, 9, 0, 2];

    foreach (i, source; sources)
        assert(evaluate(source) == expected[i]);
}