1 /// Author: Aziz Köksal
2 /// License: GPL3
3 /// $(Maturity very high)
4 module dil.ast.Visitor;
5 
6 import dil.ast.Node,
7        dil.ast.Declarations,
8        dil.ast.Expressions,
9        dil.ast.Statements,
10        dil.ast.Types,
11        dil.ast.Parameters;
12 import common;
13 
/// Generates the source code of one visit() method per node class.
///
/// The generated text begins with the marker method `_beforeFirstVisitMethod`.
/// Its vtable slot is used to locate the first visit() method. Then, for every
/// name in NodeClassNames, a method like the following is emitted:
/// ---
/// returnType!(ClassDecl) visit(ClassDecl node){return unhandled(node).to!(ClassDecl);}
/// returnType!(CommaExpr) visit(CommaExpr node){return unhandled(node).to!(CommaExpr);}
/// ---
/// The result is mixed into class Visitor.
char[] generateVisitMethods()
{
  auto code = "void _beforeFirstVisitMethod(){}".dup;
  foreach (name; NodeClassNames)
  {
    // Signature, e.g. "returnType!(ClassDecl) visit(ClassDecl node)".
    code ~= "returnType!(" ~ name ~ ") visit(" ~ name ~ " node)";
    // Body: forward to unhandled() and convert its result with to!(name).
    code ~= "{return unhandled(node).to!(" ~ name ~ ");}\n";
  }
  return code;
}
29 
/// Like generateVisitMethods, but the generated visit() methods return void
/// and only call unhandled(node), e.g.:
/// ---
/// void visit(ClassDecl node){unhandled(node);}
/// ---
/// The output also starts with the `_beforeFirstVisitMethod` marker.
/// The result is mixed into class Visitor2.
char[] generateVisitMethods2()
{
  auto result = "void _beforeFirstVisitMethod(){}".dup;
  foreach (cls; NodeClassNames)
  {
    result ~= "void visit(" ~ cls ~ " node)";
    result ~= "{unhandled(node);}\n";
  }
  return result;
}
38 
/// Maps a node class to the return type of its visit() method in Visitor.
///
/// The base classes are tested in this order, and the first match wins:
/// Declaration, Statement, Expression, TypeNode. Any other class maps to Node.
template returnType(Class)
{
  static if      (is(Class : Declaration)) alias returnType = Declaration;
  else static if (is(Class : Statement))   alias returnType = Statement;
  else static if (is(Class : Expression))  alias returnType = Expression;
  else static if (is(Class : TypeNode))    alias returnType = TypeNode;
  else                                     alias returnType = Node;
}
56 
/// Invokes the visit() method of visitorInstance that belongs to node kind k,
/// passing n as the argument, and returns its result.
///
/// The method is fetched from the dynamic class's vtable at slot
/// indexOfFirstVisitMethod + k. This means the visit() methods must appear in
/// the vtable in NodeKind order.
/// NOTE(review): relies on NodeClassNames and NodeKind sharing one ordering;
/// confirm in dil.ast.Node.
Ret callVisitMethod(Ret)(Object visitorInstance, Node n, NodeKind k)
{
  auto vtable = typeid(visitorInstance).vtbl;
  void* slot = vtable[indexOfFirstVisitMethod + k];
  // Assemble a delegate by hand: the context pointer is the visitor object
  // itself, and the function pointer is the vtable entry found above.
  Ret delegate(Node) dg;
  dg.ptr = cast(void*)visitorInstance;
  dg.funcptr = cast(Ret function(Node))slot;
  return dg(n);
}
68 
/// Implements a variation of the visitor pattern.
///
/// Inherited by classes that need to traverse a D syntax tree
/// and do computations, transformations and other things on it.
///
/// Subclasses override the visit() overloads for the node classes they
/// care about. Every visit() that is not overridden forwards to unhandled()
/// and returns a node of the type given by returnType. By default this is
/// the visited node itself.
///
/// NOTE: the order of the virtual members matters. The mixin must come
/// first so that the visit() methods directly follow _beforeFirstVisitMethod
/// in the vtable. The vtable layout must also stay in step with Visitor2:
/// the module constructor asserts equal vtable lengths and a common index
/// for both classes.
abstract class Visitor
{
  mixin(generateVisitMethods());

  /// Calls the appropriate visit() method for a node.
  /// The method is selected by n.kind, so n must not be null.
  Node dispatch(Node n)
  {
    return callVisitMethod!(Node)(this, n, n.kind);
  }

  /// Allows calling the visit() method with a null node.
  /// The method is selected by the explicitly passed kind k.
  Node dispatch(Node n, NodeKind k)
  {
    return callVisitMethod!(Node)(this, n, k);
  }

  /// Called by visit() methods that were not overridden.
  /// Returns: n unchanged (default behavior).
  Node unhandled(Node n)
  { return n; }

final:
  /// Visits a Declaration and returns a Declaration.
  Declaration visitD(Declaration n)
  {
    return dispatch(n).to!(Declaration);
  }
  /// Visits a Statement and returns a Statement.
  Statement visitS(Statement n)
  {
    return dispatch(n).to!(Statement);
  }
  /// Visits an Expression and returns an Expression.
  Expression visitE(Expression n)
  {
    return dispatch(n).to!(Expression);
  }
  /// Visits a TypeNode and returns a TypeNode.
  TypeNode visitT(TypeNode n)
  {
    return dispatch(n).to!(TypeNode);
  }
  /// Visits a Node and returns a Node.
  Node visitN(Node n)
  {
    return dispatch(n);
  }
}
120 
/// The same as class Visitor, but the methods return void.
/// This class is suitable when you don't want to transform the AST.
///
/// NOTE: as in Visitor, the mixin must stay the first member. The virtual
/// members must also stay in step with Visitor, because both classes are
/// dispatched through one shared vtable index (indexOfFirstVisitMethod),
/// which is checked in the module constructor.
abstract class Visitor2
{
  mixin(generateVisitMethods2());

  /// Calls the appropriate visit() method for a node.
  /// The method is selected by n.kind, so n must not be null.
  void dispatch(Node n)
  {
    callVisitMethod!(void)(this, n, n.kind);
  }

  /// Allows calling the visit() method with a null node.
  /// The method is selected by the explicitly passed kind k.
  void dispatch(Node n, NodeKind k)
  {
    callVisitMethod!(void)(this, n, k);
  }

  /// Called by visit() methods that were not overridden.
  /// Does nothing by default.
  void unhandled(Node n)
  {}

final:
  // Every entry point forwards to dispatch(). Aliasing visit to dispatch
  // presumably lets visit(node) also be called when the node's static type
  // is only a base class such as Node.
  alias visit = dispatch,
       visitN = dispatch,
       visitD = dispatch,
       visitS = dispatch,
       visitE = dispatch,
       visitT = dispatch;
}
151 
/// Index into the vtable of the Visitor classes.
/// It is the slot of the first visit() method, shared by Visitor and Visitor2.
private static const size_t indexOfFirstVisitMethod;

/// Initializes indexOfFirstVisitMethod for both Visitor classes.
///
/// Scans the vtables of Visitor and Visitor2 for the marker method
/// _beforeFirstVisitMethod. The visit() methods begin in the next slot.
/// The sanity checks use assert(0) because a plain assert(cond) is
/// compiled out under -release. A failed lookup would otherwise leave the
/// index at 0 and make every dispatch call an arbitrary vtable entry.
static this()
{
  auto vtbl = typeid(Visitor).vtbl;
  auto vtbl2 = typeid(Visitor2).vtbl;
  // One index is used for both classes, so their layouts must agree.
  if (vtbl.length != vtbl2.length)
    assert(0, "Visitor and Visitor2 have different vtable lengths");
  size_t i;
  foreach (j, func; vtbl)
    if (func is &Visitor._beforeFirstVisitMethod &&
        vtbl2[j] is &Visitor2._beforeFirstVisitMethod)
    {
      i = j + 1;
      break;
    }
  if (i == 0)
    assert(0, "couldn't find first visit method in the vtable");
  if (i >= vtbl.length)
    assert(0, "no visit methods follow _beforeFirstVisitMethod in the vtable");
  indexOfFirstVisitMethod = i;
}
173 
/// Tests that visit() methods overridden in subclasses of Visitor and
/// Visitor2 are called both directly and through dispatch(). The dispatch()
/// path goes through the vtable lookup in callVisitMethod.
void testVisitor()
{
  scope msg = new UnittestMsg("Testing class Visitor.");

  class TestVisitor : Visitor
  {
    NullExpr visited; // Records the node passed to visit(NullExpr).
    alias visit = super.visit;
    override Expression visit(NullExpr e)
    {
      visited = e;
      return e;
    }
  }

  class TestVisitor2 : Visitor2
  {
    NullExpr visited; // Records the node passed to visit(NullExpr).
    alias visit = super.visit;
    override void visit(NullExpr e)
    {
      visited = e;
    }
  }

  auto ie = new NullExpr();
  auto v1 = new TestVisitor();
  auto v2 = new TestVisitor2();

  // Direct calls: overload resolution picks the overridden methods.
  assert(v1.visit(ie) is ie, "Visitor.visit(NullExpr) was not called");
  assert(v1.visited is ie, "Visitor.visit(NullExpr) was not called");
  v2.visit(ie);
  assert(v2.visited is ie, "Visitor2.visit(NullExpr) was not called");

  // dispatch() must reach the same methods through the vtable index.
  // Visitor.unhandled() also returns the node unchanged, so the recorded
  // field, not the return value, proves the right method ran.
  v1.visited = null;
  assert(v1.dispatch(ie) is ie && v1.visited is ie,
    "Visitor.dispatch() did not call visit(NullExpr)");
  v2.visited = null;
  v2.dispatch(ie);
  assert(v2.visited is ie, "Visitor2.dispatch() did not call visit(NullExpr)");
}