"""
06_ast_visitor.py - AST Node Classes and Visitor Pattern

Demonstrates the Visitor design pattern applied to an Abstract Syntax Tree.
The visitor pattern separates tree traversal logic from node definitions,
making it easy to add new operations without modifying node classes.

AST nodes are defined using Python dataclasses.
Three visitors are implemented:
  1. EvalVisitor        - evaluates the expression tree to a value
                          (number, bool or string)
  2. TypeCheckVisitor   - infers and checks types (int, float, bool, str)
  3. PrettyPrintVisitor - produces a formatted source code string

The language supports:
  - Numeric literals (int and float)
  - Boolean literals
  - String literals
  - Arithmetic: +, -, *, /, % and unary -
  - Comparison: ==, !=, <, >, <=, >=
  - Logical: and, or, not
  - String concatenation with +
  - Variables and let bindings
  - If expressions (ternary style)
  - Calls to built-in functions: abs, min, max, sqrt, len, str, int, float

Topics covered:
  - Dataclass-based AST nodes
  - Abstract base class for visitors
  - Dispatch by method name (visit_<NodeType>), Python's stand-in for
    classic double dispatch
  - Type inference without annotations
  - Environment (symbol table) as a dict
"""
32
from __future__ import annotations

import math
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Optional, Union
37
38
39# ---------------------------------------------------------------------------
40# AST Node definitions
41# ---------------------------------------------------------------------------
42
class Node:
    """Root of the AST class hierarchy.

    Carries no fields or behavior of its own; it only gives every node
    dataclass (and the visitors' type annotations) a common type.
    """
46
47
@dataclass
class IntLit(Node):
    """Integer literal, e.g. ``42``."""

    value: int  # the literal's integer value
51
@dataclass
class FloatLit(Node):
    """Floating-point literal, e.g. ``2.5``."""

    value: float  # the literal's float value
55
@dataclass
class BoolLit(Node):
    """Boolean literal (true / false)."""

    value: bool  # the literal's truth value
59
@dataclass
class StrLit(Node):
    """String literal."""

    value: str  # the string's contents, without surrounding quotes
63
@dataclass
class Var(Node):
    """Reference to a variable, resolved by name in a visitor's environment."""

    name: str  # the variable's identifier
67
@dataclass
class BinOp(Node):
    """Binary operation ``left <op> right``.

    ``op`` is an arithmetic operator, a comparison, or one of the logical
    operators ``and`` / ``or``.
    """

    op: str # '+', '-', '*', '/', '%', '==', '!=', '<', '>', '<=', '>=', 'and', 'or'
    left: Node  # left operand
    right: Node  # right operand
73
@dataclass
class UnaryOp(Node):
    """Prefix operation ``<op> operand`` (numeric negation or logical not)."""

    op: str # '-', 'not'
    operand: Node  # the expression the operator applies to
78
@dataclass
class IfExpr(Node):
    """Ternary if: condition ? then_expr : else_expr

    Exactly one branch is selected, based on the value of ``condition``.
    """
    condition: Node  # selects the branch
    then_expr: Node  # result when the condition holds
    else_expr: Node  # result otherwise
85
@dataclass
class LetExpr(Node):
    """let name = value in body

    ``name`` is bound to the result of ``value`` while ``body`` is processed;
    the binding is local to ``body``.
    """
    name: str  # variable being bound
    value: Node  # expression whose result is bound to ``name``
    body: Node  # expression in which the binding is visible
92
@dataclass
class FuncCall(Node):
    """Call of a function by name, e.g. ``sqrt(x)``."""

    name: str  # function name, looked up by each visitor
    # Argument expressions in call order; default_factory gives every
    # instance its own fresh list rather than one shared default.
    args: list[Node] = field(default_factory=list)
97
98
99# ---------------------------------------------------------------------------
100# Visitor base class
101# ---------------------------------------------------------------------------
102
class Visitor(ABC):
    """
    Base class for AST visitors.

    A concrete visitor defines one ``visit_<ClassName>`` method per node
    class it understands. :meth:`visit` finds that method from the node's
    runtime class name and calls it; node types without a matching method
    are routed to :meth:`generic_visit`.
    """

    def visit(self, node: Node) -> Any:
        """Route *node* to ``visit_<type name>``, or to :meth:`generic_visit`."""
        handler = getattr(self, "visit_" + type(node).__name__, self.generic_visit)
        return handler(node)

    def generic_visit(self, node: Node) -> Any:
        """Fallback for node types this visitor has no handler for.

        Raises:
            NotImplementedError: always, naming both the visitor and node class.
        """
        visitor_name = type(self).__name__
        node_name = type(node).__name__
        raise NotImplementedError(f"{visitor_name} has no handler for {node_name}")
119
120
121# ---------------------------------------------------------------------------
122# Visitor 1: Evaluator
123# ---------------------------------------------------------------------------
124
class EvalError(Exception):
    """Raised by EvalVisitor when an expression cannot be evaluated."""
127
128
class EvalVisitor(Visitor):
    """
    Evaluates an AST expression to a Python value.

    Variable lookups read ``self.env``, a dict mapping names to values. A
    ``let`` adds its binding to that dict while its body is evaluated and
    restores the previous state afterwards -- also when the body raises --
    so the visitor stays usable after a failed evaluation.

    Every evaluation failure is reported as :class:`EvalError`: undefined
    variables, unknown operators or functions, and Python-level errors
    raised while computing a result (division by zero, mismatched operand
    types, invalid conversions, ``sqrt`` of a negative number, ...). In the
    latter case the original Python exception is chained as ``__cause__``.
    """

    # Python exceptions that signal an error in the *evaluated program*
    # rather than a missing handler; they are re-raised as EvalError.
    _RUNTIME_ERRORS = (ArithmeticError, TypeError, ValueError)

    # Marker for "name had no binding before this let". A private object is
    # used (not None) so a variable legitimately bound to None is restored.
    _UNBOUND = object()

    # Eagerly evaluated binary operators. 'and' / 'or' are absent on purpose:
    # visit_BinOp handles them with short-circuit evaluation.
    _BINARY_OPS = {
        '+': lambda a, b: a + b,
        '-': lambda a, b: a - b,
        '*': lambda a, b: a * b,
        '/': lambda a, b: a / b,
        '%': lambda a, b: a % b,
        '==': lambda a, b: a == b,
        '!=': lambda a, b: a != b,
        '<': lambda a, b: a < b,
        '>': lambda a, b: a > b,
        '<=': lambda a, b: a <= b,
        '>=': lambda a, b: a >= b,
    }

    # Functions callable from the expression language. The lambdas pin the
    # accepted arity (e.g. str() with no argument is rejected).
    _BUILTINS = {
        'abs': lambda x: abs(x),
        'min': lambda *xs: min(xs),
        'max': lambda *xs: max(xs),
        # math.sqrt raises ValueError for negative input, whereas x ** 0.5
        # would silently produce a complex number.
        'sqrt': lambda x: math.sqrt(x),
        'len': lambda s: len(s),
        'str': lambda x: str(x),
        'int': lambda x: int(x),
        'float': lambda x: float(x),
    }

    def __init__(self, env: Optional[dict[str, Any]] = None):
        # Only None means "no environment"; an empty dict from the caller is
        # used as-is instead of being silently replaced.
        self.env: dict[str, Any] = env if env is not None else {}

    def visit_IntLit(self, node: IntLit) -> int:
        return node.value

    def visit_FloatLit(self, node: FloatLit) -> float:
        return node.value

    def visit_BoolLit(self, node: BoolLit) -> bool:
        return node.value

    def visit_StrLit(self, node: StrLit) -> str:
        return node.value

    def visit_Var(self, node: Var) -> Any:
        """Return the bound value of the variable; EvalError if unbound."""
        if node.name not in self.env:
            raise EvalError(f"Undefined variable: {node.name!r}")
        return self.env[node.name]

    def visit_BinOp(self, node: BinOp) -> Any:
        """Evaluate a binary operation.

        Raises:
            EvalError: for an unknown operator or when the operation fails
                on the operand values (e.g. division by zero).
        """
        left = self.visit(node.left)
        # Short-circuit evaluation: the right operand is only evaluated when
        # needed, and the deciding operand's value is returned unchanged.
        if node.op == 'and':
            return left and self.visit(node.right)
        if node.op == 'or':
            return left or self.visit(node.right)
        right = self.visit(node.right)
        func = self._BINARY_OPS.get(node.op)
        if func is None:
            raise EvalError(f"Unknown operator: {node.op!r}")
        try:
            return func(left, right)
        except self._RUNTIME_ERRORS as err:
            raise EvalError(
                f"Cannot apply {node.op!r} to {left!r} and {right!r}: {err}"
            ) from err

    def visit_UnaryOp(self, node: UnaryOp) -> Any:
        """Evaluate ``-operand`` or ``not operand``.

        Raises:
            EvalError: for an unknown operator or a non-negatable operand.
        """
        v = self.visit(node.operand)
        try:
            if node.op == '-':
                return -v
            if node.op == 'not':
                return not v
        except self._RUNTIME_ERRORS as err:
            raise EvalError(
                f"Cannot apply unary {node.op!r} to {v!r}: {err}"
            ) from err
        raise EvalError(f"Unknown unary operator: {node.op!r}")

    def visit_IfExpr(self, node: IfExpr) -> Any:
        """Evaluate the condition, then only the selected branch."""
        cond = self.visit(node.condition)
        return self.visit(node.then_expr if cond else node.else_expr)

    def visit_LetExpr(self, node: LetExpr) -> Any:
        """Evaluate ``body`` with ``name`` bound to the value of ``value``.

        The previous binding of ``name`` (or its absence) is restored in a
        ``finally`` block, so it is restored even if the body raises.
        """
        val = self.visit(node.value)
        previous = self.env.get(node.name, self._UNBOUND)
        self.env[node.name] = val
        try:
            return self.visit(node.body)
        finally:
            # Restore previous binding (lexical scoping)
            if previous is self._UNBOUND:
                self.env.pop(node.name, None)
            else:
                self.env[node.name] = previous

    def visit_FuncCall(self, node: FuncCall) -> Any:
        """Evaluate the arguments left to right, then call the built-in.

        Raises:
            EvalError: for an unknown function, a wrong number of arguments,
                or when the built-in fails on its arguments.
        """
        args = [self.visit(a) for a in node.args]
        func = self._BUILTINS.get(node.name)
        if func is None:
            raise EvalError(f"Unknown function: {node.name!r}")
        try:
            return func(*args)
        except self._RUNTIME_ERRORS as err:
            raise EvalError(f"Error in {node.name}(): {err}") from err
217
218
219# ---------------------------------------------------------------------------
220# Visitor 2: Type Checker
221# ---------------------------------------------------------------------------
222
class TypeError_(Exception):
    """Exception for type-checking failures.

    The trailing underscore keeps it distinct from the built-in ``TypeError``.
    NOTE(review): TypeCheckVisitor as written does not raise this; it records
    messages in its ``errors`` list instead.
    """
225
226
# Type tags produced by TypeCheckVisitor. They are plain strings, so they can
# be printed and compared directly.
INT, FLOAT, BOOL, STR = 'int', 'float', 'bool', 'str'
# Meaning "int or float". NOTE(review): not referenced by the visible code.
NUM = 'num'
233
234
class TypeCheckVisitor(Visitor):
    """
    Infers the type of an expression without evaluating it.

    Inferred types are the tags INT, FLOAT, BOOL and STR. Type problems do
    not raise: each one is appended as a message to ``self.errors`` and the
    offending sub-expression gets the type ``'error'``. An operation whose
    operand already has type ``'error'`` returns ``'error'`` without adding
    another message, so each root cause is reported once.

    ``self.type_env`` maps variable names to type tags. A ``let`` binding is
    added to it while the body is checked, and the previous state is
    restored afterwards.
    """

    # Type assigned to sub-expressions that failed to type-check.
    _ERROR = 'error'

    # Marker for "name had no binding before this let" (see visit_LetExpr).
    _UNBOUND = object()

    # Built-in functions that take exactly one argument.
    _UNARY_BUILTINS = ('abs', 'sqrt', 'len', 'str', 'int', 'float')

    def __init__(self, type_env: Optional[dict[str, str]] = None):
        # Only None means "no environment"; an empty dict from the caller is
        # used as-is.
        self.type_env: dict[str, str] = type_env if type_env is not None else {}
        self.errors: list[str] = []

    def _error(self, msg: str) -> str:
        """Record *msg* and return the error type."""
        self.errors.append(msg)
        return self._ERROR

    def visit_IntLit(self, node: IntLit) -> str: return INT
    def visit_FloatLit(self, node: FloatLit) -> str: return FLOAT
    def visit_BoolLit(self, node: BoolLit) -> str: return BOOL
    def visit_StrLit(self, node: StrLit) -> str: return STR

    def visit_Var(self, node: Var) -> str:
        if node.name not in self.type_env:
            return self._error(f"Undefined variable: {node.name!r}")
        return self.type_env[node.name]

    def _is_numeric(self, t: str) -> bool:
        return t in (INT, FLOAT)

    def _numeric_result(self, t1: str, t2: str) -> str:
        """int op int -> int; float op float -> float; int op float -> float."""
        return FLOAT if FLOAT in (t1, t2) else INT

    def visit_BinOp(self, node: BinOp) -> str:
        lt = self.visit(node.left)
        rt = self.visit(node.right)
        if self._ERROR in (lt, rt):
            # Already reported where it arose; don't add a cascading error.
            return self._ERROR

        if node.op in ('+', '-', '*', '/', '%'):
            if node.op == '+' and lt == STR and rt == STR:
                return STR  # string concatenation
            if self._is_numeric(lt) and self._is_numeric(rt):
                # '/' is true division at evaluation time (5 / 2 == 2.5),
                # so its result is a float even for two int operands.
                if node.op == '/':
                    return FLOAT
                return self._numeric_result(lt, rt)
            return self._error(
                f"Operator {node.op!r} not applicable to {lt} and {rt}"
            )

        if node.op in ('<', '>', '<=', '>='):
            if self._is_numeric(lt) and self._is_numeric(rt):
                return BOOL
            return self._error(
                f"Comparison {node.op!r} not applicable to {lt} and {rt}"
            )

        if node.op in ('==', '!='):
            if lt == rt or (self._is_numeric(lt) and self._is_numeric(rt)):
                return BOOL
            return self._error(f"Cannot compare {lt} with {rt}")

        if node.op in ('and', 'or'):
            if lt == BOOL and rt == BOOL:
                return BOOL
            return self._error(
                f"Logical {node.op!r} requires bool operands, got {lt} and {rt}"
            )

        return self._error(f"Unknown operator: {node.op!r}")

    def visit_UnaryOp(self, node: UnaryOp) -> str:
        t = self.visit(node.operand)
        if t == self._ERROR:
            return self._ERROR
        if node.op == '-':
            if self._is_numeric(t): return t
            return self._error(f"Unary '-' not applicable to {t}")
        if node.op == 'not':
            if t == BOOL: return BOOL
            return self._error(f"'not' requires bool, got {t}")
        return self._error(f"Unknown unary operator: {node.op!r}")

    def visit_IfExpr(self, node: IfExpr) -> str:
        """Check the condition is bool and both branches agree.

        Numeric branches are unified (int and float -> float). A condition
        error is recorded, but the branch type is still returned.
        """
        ct = self.visit(node.condition)
        if ct not in (BOOL, self._ERROR):
            self._error(f"If condition must be bool, got {ct}")
        tt = self.visit(node.then_expr)
        et = self.visit(node.else_expr)
        if self._ERROR in (tt, et):
            return self._ERROR
        if self._is_numeric(tt) and self._is_numeric(et):
            # Either branch may be taken, so use the promoted type.
            return self._numeric_result(tt, et)
        if tt != et:
            return self._error(f"If branches have different types: {tt} vs {et}")
        return tt

    def visit_LetExpr(self, node: LetExpr) -> str:
        """Type the body with ``name`` bound to the type of ``value``."""
        val_type = self.visit(node.value)
        previous = self.type_env.get(node.name, self._UNBOUND)
        self.type_env[node.name] = val_type
        try:
            return self.visit(node.body)
        finally:
            # Restore the outer binding even if checking the body raised.
            if previous is self._UNBOUND:
                self.type_env.pop(node.name, None)
            else:
                self.type_env[node.name] = previous

    def visit_FuncCall(self, node: FuncCall) -> str:
        """Check arity and argument types of a built-in call; return its type."""
        arg_types = [self.visit(a) for a in node.args]
        if self._ERROR in arg_types:
            return self._ERROR
        name = node.name
        if name in self._UNARY_BUILTINS:
            if len(arg_types) != 1:
                return self._error(
                    f"{name}() takes exactly 1 argument, got {len(arg_types)}"
                )
            return self._unary_builtin_type(name, arg_types[0])
        if name in ('min', 'max'):
            if not arg_types:
                return self._error(f"{name}() requires at least 1 argument")
            if not all(self._is_numeric(t) for t in arg_types):
                return self._error(
                    f"{name}() requires numeric arguments, got {', '.join(arg_types)}"
                )
            return FLOAT if FLOAT in arg_types else INT
        return self._error(f"Unknown function: {name!r}")

    def _unary_builtin_type(self, name: str, t: str) -> str:
        """Result type of one-argument built-in *name* applied to type *t*."""
        if name in ('str', 'int', 'float'):
            # Conversions accept any argument type statically; e.g. int('abc')
            # can still fail at evaluation time.
            return {'str': STR, 'int': INT, 'float': FLOAT}[name]
        if name == 'len':
            if t == STR:
                return INT
            return self._error(f"len() requires str, got {t}")
        # abs / sqrt
        if not self._is_numeric(t):
            return self._error(f"{name}() requires a numeric argument, got {t}")
        return t if name == 'abs' else FLOAT
342
343
344# ---------------------------------------------------------------------------
345# Visitor 3: Pretty Printer
346# ---------------------------------------------------------------------------
347
class PrettyPrintVisitor(Visitor):
    """
    Produces a readable string representation of the AST.

    Parentheses are added only where needed to keep the output unambiguous:
      * binary operators follow PREC (higher binds tighter) and are treated
        as left-associative, so an equal-precedence right operand is
        parenthesized (``a - (b - c)``);
      * unary operators parenthesize binary and ``let`` operands;
      * ``let`` expressions used as a binary operand or as another let's
        value are parenthesized, since a let body would otherwise extend
        over the rest of the expression;
      * if-expressions are always printed inside parentheses.
    String literals are double-quoted, with backslash escapes for quotes,
    backslashes and common control characters.
    """

    PREC = {
        'or': 1, 'and': 2,
        '==': 3, '!=': 3,
        '<': 4, '>': 4, '<=': 4, '>=': 4,
        '+': 5, '-': 5,
        '*': 6, '/': 6, '%': 6,
    }

    # Characters that must be escaped inside a double-quoted string literal.
    # str.translate maps each character independently, so a backslash added
    # by one escape is never re-escaped.
    _STR_ESCAPES = str.maketrans({
        '\\': '\\\\',
        '"': '\\"',
        '\n': '\\n',
        '\r': '\\r',
        '\t': '\\t',
    })

    def _prec(self, op: str) -> int:
        """Precedence of *op*; unknown operators bind tightest (99)."""
        return self.PREC.get(op, 99)

    def visit_IntLit(self, node: IntLit) -> str: return str(node.value)
    def visit_FloatLit(self, node: FloatLit) -> str: return str(node.value)
    def visit_BoolLit(self, node: BoolLit) -> str: return 'true' if node.value else 'false'
    def visit_Var(self, node: Var) -> str: return node.name

    def visit_StrLit(self, node: StrLit) -> str:
        return '"' + node.value.translate(self._STR_ESCAPES) + '"'

    def _operand(self, child: Node, parent_prec: int, is_right: bool) -> str:
        """Print one operand of a binary operator, parenthesized if required."""
        text = self.visit(child)
        if isinstance(child, LetExpr):
            return f"({text})"
        if isinstance(child, BinOp):
            child_prec = self._prec(child.op)
            # Lower precedence always needs parens; equal precedence only on
            # the right, because operators are left-associative.
            if child_prec < parent_prec or (is_right and child_prec == parent_prec):
                return f"({text})"
        return text

    def visit_BinOp(self, node: BinOp) -> str:
        prec = self._prec(node.op)
        l = self._operand(node.left, prec, is_right=False)
        r = self._operand(node.right, prec, is_right=True)
        return f"{l} {node.op} {r}"

    def visit_UnaryOp(self, node: UnaryOp) -> str:
        inner = self.visit(node.operand)
        if isinstance(node.operand, (BinOp, LetExpr)):
            inner = f"({inner})"
        # Keyword operators need a separating space: 'not' applied to 'flag'
        # must print as 'not flag', not 'notflag'.
        sep = ' ' if node.op.isalpha() else ''
        return f"{node.op}{sep}{inner}"

    def visit_IfExpr(self, node: IfExpr) -> str:
        cond = self.visit(node.condition)
        then = self.visit(node.then_expr)
        els = self.visit(node.else_expr)
        return f"({cond} ? {then} : {els})"

    def visit_LetExpr(self, node: LetExpr) -> str:
        val = self.visit(node.value)
        if isinstance(node.value, LetExpr):
            # Without parens the inner let's body would swallow the 'in'.
            val = f"({val})"
        body = self.visit(node.body)
        return f"let {node.name} = {val} in {body}"

    def visit_FuncCall(self, node: FuncCall) -> str:
        args = ', '.join(self.visit(a) for a in node.args)
        return f"{node.name}({args})"
401
402
403# ---------------------------------------------------------------------------
404# Demo
405# ---------------------------------------------------------------------------
406
def demo(label: str, ast: Node) -> None:
    """Print *label* and the output of all three visitors for *ast*.

    The type checker and evaluator use fixed sample environments binding
    ``x`` (int 5), ``y`` (float 2.0), ``s`` (str 'hello') and ``flag``
    (bool True). Evaluation failures are printed rather than raised, so one
    failing example does not abort the remaining ones.
    """
    print(f"\n{'─'*52}")
    print(f"Expression: {label}")

    pp = PrettyPrintVisitor()
    printed = pp.visit(ast)
    print(f" Pretty: {printed}")

    tc = TypeCheckVisitor(type_env={'x': INT, 'y': FLOAT, 's': STR, 'flag': BOOL})
    inferred = tc.visit(ast)
    if tc.errors:
        print(" Type: ERROR")
        for e in tc.errors:
            print(f" {e}")
    else:
        print(f" Type: {inferred}")

    ev = EvalVisitor(env={'x': 5, 'y': 2.0, 's': 'hello', 'flag': True})
    try:
        result = ev.visit(ast)
    except (EvalError, ArithmeticError, TypeError, ValueError) as e:
        # Besides EvalError, catch the Python-level errors an evaluated
        # program can trigger (division by zero, 's' + 1, int('abc')) in
        # case the evaluator lets them escape unwrapped.
        print(f" Eval err: {e}")
    else:
        print(f" Value: {result!r}")
430
431
def main():
    """Run :func:`demo` on a fixed, ordered list of example expressions."""
    banner = "=" * 60
    print(banner)
    print("AST Visitor Pattern Demo")
    print(banner)

    examples = [
        # 1. Simple arithmetic
        ("x * 2 + y",
         BinOp('+', BinOp('*', Var('x'), IntLit(2)), Var('y'))),
        # 2. Comparison combined with a logical operator
        ("x > 3 and flag",
         BinOp('and', BinOp('>', Var('x'), IntLit(3)), Var('flag'))),
        # 3. Ternary if
        ("flag ? x : -x",
         IfExpr(Var('flag'), Var('x'), UnaryOp('-', Var('x')))),
        # 4. Let expression
        ("let z = x + 1 in z * z",
         LetExpr('z', BinOp('+', Var('x'), IntLit(1)),
                 BinOp('*', Var('z'), Var('z')))),
        # 5. Function call
        ("sqrt(x * x + y * y)",
         FuncCall('sqrt', [
             BinOp('+',
                   BinOp('*', Var('x'), Var('x')),
                   BinOp('*', Var('y'), Var('y'))),
         ])),
        # 6. String concatenation
        ('s + " world"',
         BinOp('+', Var('s'), StrLit(' world'))),
        # 7. Type error: int + bool
        ("x + flag (type error)",
         BinOp('+', Var('x'), Var('flag'))),
        # 8. Nested let
        ("let a = 3 in let b = a + 2 in a * b",
         LetExpr('a', IntLit(3),
                 LetExpr('b', BinOp('+', Var('a'), IntLit(2)),
                         BinOp('*', Var('a'), Var('b'))))),
    ]
    for label, ast in examples:
        demo(label, ast)
475
476
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()