06_ast_visitor.py

Download
python 479 lines 14.9 KB
  1"""
  206_ast_visitor.py - AST Node Classes and Visitor Pattern
  3
  4Demonstrates the Visitor design pattern applied to an Abstract Syntax Tree.
  5The visitor pattern separates tree traversal logic from node definitions,
  6making it easy to add new operations without modifying node classes.
  7
  8AST nodes are defined using Python dataclasses.
Three visitors are implemented:
  1. EvalVisitor        - evaluates the expression tree to a Python value
                          (int, float, bool or str)
  2. TypeCheckVisitor   - infers and checks types (int, float, bool, str)
  3. PrettyPrintVisitor - produces a formatted source-code string
 13
The language supports:
  - Numeric literals (int and float)
  - Boolean literals
  - String literals
  - Arithmetic: +, -, *, /, % and unary -
  - Comparison: ==, !=, <, >, <=, >=
  - Logical: and, or, not
  - String concatenation with +
  - Variables and let bindings
  - If expressions (ternary style)
  - Calls to a fixed set of built-in functions
    (abs, min, max, sqrt, len, str, int, float)
 24
 25Topics covered:
 26  - Dataclass-based AST nodes
 27  - Abstract base class for visitors
 28  - Double dispatch via visit_<NodeType> method naming
 29  - Type inference without annotations
 30  - Environment (symbol table) as a dict
 31"""
 32
 33from __future__ import annotations
 34from abc import ABC, abstractmethod
 35from dataclasses import dataclass, field
 36from typing import Any, Optional, Union
 37
 38
 39# ---------------------------------------------------------------------------
 40# AST Node definitions
 41# ---------------------------------------------------------------------------
 42
 43class Node:
 44    """Base class for all AST nodes."""
 45    pass
 46
 47
 48@dataclass
 49class IntLit(Node):
 50    value: int
 51
 52@dataclass
 53class FloatLit(Node):
 54    value: float
 55
 56@dataclass
 57class BoolLit(Node):
 58    value: bool
 59
 60@dataclass
 61class StrLit(Node):
 62    value: str
 63
 64@dataclass
 65class Var(Node):
 66    name: str
 67
 68@dataclass
 69class BinOp(Node):
 70    op: str     # '+', '-', '*', '/', '%', '==', '!=', '<', '>', '<=', '>=', 'and', 'or'
 71    left: Node
 72    right: Node
 73
 74@dataclass
 75class UnaryOp(Node):
 76    op: str     # '-', 'not'
 77    operand: Node
 78
 79@dataclass
 80class IfExpr(Node):
 81    """Ternary if: condition ? then_expr : else_expr"""
 82    condition: Node
 83    then_expr: Node
 84    else_expr: Node
 85
 86@dataclass
 87class LetExpr(Node):
 88    """let name = value in body"""
 89    name: str
 90    value: Node
 91    body: Node
 92
 93@dataclass
 94class FuncCall(Node):
 95    name: str
 96    args: list[Node] = field(default_factory=list)
 97
 98
 99# ---------------------------------------------------------------------------
100# Visitor base class
101# ---------------------------------------------------------------------------
102
class Visitor(ABC):
    """
    Base class for AST visitors using name-based double dispatch.

    A subclass defines ``visit_<ClassName>`` for each node class it
    supports. ``visit`` looks that method up from the node's class name;
    when no such method exists it falls back to ``generic_visit``, which
    raises NotImplementedError.
    """

    def visit(self, node: Node) -> Any:
        """Call the ``visit_<type(node).__name__>`` handler for *node*."""
        fallback = self.generic_visit
        handler = getattr(self, "visit_" + type(node).__name__, fallback)
        return handler(node)

    def generic_visit(self, node: Node) -> Any:
        """Handle a node type that has no ``visit_*`` method by raising."""
        visitor_name = type(self).__name__
        node_name = type(node).__name__
        raise NotImplementedError(
            f"{visitor_name} has no handler for {node_name}"
        )
119
120
121# ---------------------------------------------------------------------------
122# Visitor 1: Evaluator
123# ---------------------------------------------------------------------------
124
class EvalError(Exception):
    """Raised by EvalVisitor when an expression cannot be evaluated."""
127
128
class EvalVisitor(Visitor):
    """
    Evaluates an AST expression to a Python value (int, float, bool or str).

    Variable values are read from ``self.env`` (name -> value). ``let``
    bindings extend that dict while their body is evaluated and restore the
    previous state afterwards, even if evaluating the body raises.

    Operators are Python's own, so ``/`` is true division and dividing by
    zero raises ZeroDivisionError rather than EvalError.
    """

    # Marks "name had no binding" -- distinct from a binding whose value is None.
    _UNBOUND = object()

    # Operator and builtin tables are built once per class rather than on
    # every visit_BinOp / visit_FuncCall call.
    _BINARY_OPS: dict[str, Any] = {
        '+':  lambda a, b: a + b,
        '-':  lambda a, b: a - b,
        '*':  lambda a, b: a * b,
        '/':  lambda a, b: a / b,
        '%':  lambda a, b: a % b,
        '==': lambda a, b: a == b,
        '!=': lambda a, b: a != b,
        '<':  lambda a, b: a < b,
        '>':  lambda a, b: a > b,
        '<=': lambda a, b: a <= b,
        '>=': lambda a, b: a >= b,
    }

    _BUILTINS: dict[str, Any] = {
        'abs':   lambda x: abs(x),
        'min':   lambda *xs: min(xs),
        'max':   lambda *xs: max(xs),
        'sqrt':  lambda x: x ** 0.5,
        'len':   lambda s: len(s),
        'str':   lambda x: str(x),
        'int':   lambda x: int(x),
        'float': lambda x: float(x),
    }

    def __init__(self, env: Optional[dict[str, Any]] = None):
        # `is None` rather than `or`, so an explicitly passed (possibly
        # empty) dict is the one that is used.
        self.env: dict[str, Any] = env if env is not None else {}

    def visit_IntLit(self, node: IntLit) -> int:
        return node.value

    def visit_FloatLit(self, node: FloatLit) -> float:
        return node.value

    def visit_BoolLit(self, node: BoolLit) -> bool:
        return node.value

    def visit_StrLit(self, node: StrLit) -> str:
        return node.value

    def visit_Var(self, node: Var) -> Any:
        """Look up a variable; raises EvalError if it is not bound."""
        if node.name not in self.env:
            raise EvalError(f"Undefined variable: {node.name!r}")
        return self.env[node.name]

    def visit_BinOp(self, node: BinOp) -> Any:
        """Evaluate a binary operation; raises EvalError for unknown operators."""
        left = self.visit(node.left)
        # 'and' / 'or' short-circuit: the right operand is only evaluated
        # when needed, and the deciding operand itself is returned.
        if node.op == 'and':
            return left and self.visit(node.right)
        if node.op == 'or':
            return left or self.visit(node.right)
        right = self.visit(node.right)
        op = self._BINARY_OPS.get(node.op)
        if op is None:
            raise EvalError(f"Unknown operator: {node.op!r}")
        return op(left, right)

    def visit_UnaryOp(self, node: UnaryOp) -> Any:
        """Evaluate '-' or 'not'; raises EvalError for other operators."""
        v = self.visit(node.operand)
        if node.op == '-':
            return -v
        if node.op == 'not':
            return not v
        raise EvalError(f"Unknown unary operator: {node.op!r}")

    def visit_IfExpr(self, node: IfExpr) -> Any:
        """Evaluate only the branch selected by the condition's truthiness."""
        cond = self.visit(node.condition)
        return self.visit(node.then_expr if cond else node.else_expr)

    def visit_LetExpr(self, node: LetExpr) -> Any:
        """Evaluate *body* with *name* bound to *value*, then restore env."""
        val = self.visit(node.value)
        previous = self.env.get(node.name, self._UNBOUND)
        self.env[node.name] = val
        try:
            return self.visit(node.body)
        finally:
            # Restore the outer binding (lexical scoping), also on error.
            if previous is self._UNBOUND:
                self.env.pop(node.name, None)
            else:
                self.env[node.name] = previous

    def visit_FuncCall(self, node: FuncCall) -> Any:
        """
        Call a built-in function.

        Raises EvalError for unknown function names and for sqrt of a
        negative number.
        """
        args = [self.visit(a) for a in node.args]
        func = self._BUILTINS.get(node.name)
        if func is None:
            raise EvalError(f"Unknown function: {node.name!r}")
        if node.name == 'sqrt' and len(args) == 1:
            (x,) = args
            # x ** 0.5 silently yields a complex number for negative x,
            # which is outside this language's value types.
            if isinstance(x, (int, float)) and x < 0:
                raise EvalError(f"sqrt of negative number: {x!r}")
        return func(*args)
217
218
219# ---------------------------------------------------------------------------
220# Visitor 2: Type Checker
221# ---------------------------------------------------------------------------
222
class TypeError_(Exception):
    """
    Exception type for type-checking failures.

    The trailing underscore keeps it from shadowing the built-in TypeError.
    NOTE(review): TypeCheckVisitor in this module records problems in its
    ``errors`` list instead of raising this exception.
    """
225
226
# Type tags used by TypeCheckVisitor. They are plain strings, so an
# inferred type can be printed directly in messages.
INT, FLOAT, BOOL, STR = 'int', 'float', 'bool', 'str'
NUM = 'num'    # "int or float"; currently unused in this module
233
234
class TypeCheckVisitor(Visitor):
    """
    Infers the static type of an expression without evaluating it.

    Problems do not raise: each one is appended as a message to
    ``self.errors`` and the offending sub-expression gets the type tag
    'error'. An operator that receives an 'error' operand returns 'error'
    without adding a message, so one mistake is reported once instead of
    cascading up the tree.

    Variable types come from ``self.type_env`` (name -> type tag). ``let``
    bindings extend it for their body and restore it afterwards, even if
    visiting the body raises.
    """

    # Type tag given to sub-expressions that already failed to check.
    _ERR = 'error'

    # Marks "name had no binding" in type_env (distinct from any stored value).
    _UNBOUND = object()

    _KNOWN_BINARY = frozenset({
        '+', '-', '*', '/', '%',
        '<', '>', '<=', '>=',
        '==', '!=',
        'and', 'or',
    })

    # The same built-in names that EvalVisitor.visit_FuncCall implements.
    _UNARY_BUILTINS = frozenset({'abs', 'sqrt', 'len', 'str', 'int', 'float'})
    _VARIADIC_BUILTINS = frozenset({'min', 'max'})

    def __init__(self, type_env: Optional[dict[str, str]] = None):
        # `is None` rather than `or`, so an explicitly passed (possibly
        # empty) dict is the one that is used.
        self.type_env: dict[str, str] = type_env if type_env is not None else {}
        self.errors: list[str] = []

    def _error(self, msg: str) -> str:
        """Record *msg* in ``self.errors`` and return the 'error' tag."""
        self.errors.append(msg)
        return self._ERR

    def visit_IntLit(self, node: IntLit) -> str:       return INT
    def visit_FloatLit(self, node: FloatLit) -> str:   return FLOAT
    def visit_BoolLit(self, node: BoolLit) -> str:     return BOOL
    def visit_StrLit(self, node: StrLit) -> str:       return STR

    def visit_Var(self, node: Var) -> str:
        if node.name not in self.type_env:
            return self._error(f"Undefined variable: {node.name!r}")
        return self.type_env[node.name]

    def _is_numeric(self, t: str) -> bool:
        return t in (INT, FLOAT)

    def _numeric_result(self, t1: str, t2: str) -> str:
        """Widen two numeric types: int with int -> int, otherwise float."""
        if t1 == FLOAT or t2 == FLOAT:
            return FLOAT
        return INT

    def visit_BinOp(self, node: BinOp) -> str:
        lt = self.visit(node.left)
        rt = self.visit(node.right)
        op = node.op

        if op not in self._KNOWN_BINARY:
            return self._error(f"Unknown operator: {op!r}")
        if self._ERR in (lt, rt):
            return self._ERR        # already reported further down the tree

        if op in ('+', '-', '*', '/', '%'):
            if op == '+' and lt == STR and rt == STR:
                return STR          # string concatenation
            if self._is_numeric(lt) and self._is_numeric(rt):
                if op == '/':
                    # EvalVisitor uses Python's true division, which yields
                    # a float even for two ints (5 / 2 == 2.5).
                    return FLOAT
                return self._numeric_result(lt, rt)
            return self._error(
                f"Operator {op!r} not applicable to {lt} and {rt}"
            )

        if op in ('<', '>', '<=', '>='):
            if self._is_numeric(lt) and self._is_numeric(rt):
                return BOOL
            return self._error(
                f"Comparison {op!r} not applicable to {lt} and {rt}"
            )

        if op in ('==', '!='):
            if lt == rt or (self._is_numeric(lt) and self._is_numeric(rt)):
                return BOOL
            return self._error(f"Cannot compare {lt} with {rt}")

        # Remaining known operators: 'and' / 'or'.
        if lt == BOOL and rt == BOOL:
            return BOOL
        return self._error(
            f"Logical {op!r} requires bool operands, got {lt} and {rt}"
        )

    def visit_UnaryOp(self, node: UnaryOp) -> str:
        t = self.visit(node.operand)
        if node.op not in ('-', 'not'):
            return self._error(f"Unknown unary operator: {node.op!r}")
        if t == self._ERR:
            return self._ERR
        if node.op == '-':
            if self._is_numeric(t):
                return t
            return self._error(f"Unary '-' not applicable to {t}")
        if t == BOOL:
            return BOOL
        return self._error(f"'not' requires bool, got {t}")

    def visit_IfExpr(self, node: IfExpr) -> str:
        ct = self.visit(node.condition)
        if ct not in (BOOL, self._ERR):
            self._error(f"If condition must be bool, got {ct}")
        tt = self.visit(node.then_expr)
        et = self.visit(node.else_expr)
        if self._ERR in (tt, et):
            return self._ERR
        if self._is_numeric(tt) and self._is_numeric(et):
            # Mixed int/float branches widen to float, as in arithmetic.
            return self._numeric_result(tt, et)
        if tt != et:
            self._error(f"If branches have different types: {tt} vs {et}")
        return tt

    def visit_LetExpr(self, node: LetExpr) -> str:
        val_type = self.visit(node.value)
        previous = self.type_env.get(node.name, self._UNBOUND)
        self.type_env[node.name] = val_type
        try:
            return self.visit(node.body)
        finally:
            # Restore the outer binding (lexical scoping), also on error.
            if previous is self._UNBOUND:
                self.type_env.pop(node.name, None)
            else:
                self.type_env[node.name] = previous

    def visit_FuncCall(self, node: FuncCall) -> str:
        """Check a built-in call's name, arity and argument types."""
        arg_types = [self.visit(a) for a in node.args]
        name = node.name
        if name not in self._UNARY_BUILTINS and name not in self._VARIADIC_BUILTINS:
            return self._error(f"Unknown function: {name!r}")
        if self._ERR in arg_types:
            return self._ERR
        if name in self._VARIADIC_BUILTINS:
            return self._check_min_max(name, arg_types)
        if len(arg_types) != 1:
            return self._error(
                f"{name}() takes exactly 1 argument, got {len(arg_types)}"
            )
        return self._check_unary_builtin(name, arg_types[0])

    def _check_min_max(self, name: str, arg_types: list[str]) -> str:
        """min/max: at least one numeric argument; float if any is float."""
        if not arg_types:
            return self._error(f"{name}() requires at least 1 argument")
        if not all(self._is_numeric(t) for t in arg_types):
            return self._error(
                f"{name}() requires numeric arguments, got {', '.join(arg_types)}"
            )
        return FLOAT if FLOAT in arg_types else INT

    def _check_unary_builtin(self, name: str, t: str) -> str:
        """Result type of a one-argument built-in applied to type *t*."""
        if name in ('abs', 'sqrt'):
            if not self._is_numeric(t):
                return self._error(f"{name}() requires a numeric argument, got {t}")
            # abs keeps the argument's type; sqrt (x ** 0.5) yields a float.
            return t if name == 'abs' else FLOAT
        if name == 'len':
            if t != STR:
                return self._error(f"len() requires str, got {t}")
            return INT
        # str(), int() and float() accept any value type here; conversion
        # failures such as int("abc") can only be detected at run time.
        return {'str': STR, 'int': INT, 'float': FLOAT}[name]
342
343
344# ---------------------------------------------------------------------------
345# Visitor 3: Pretty Printer
346# ---------------------------------------------------------------------------
347
class PrettyPrintVisitor(Visitor):
    """
    Produces a readable source-text rendering of the AST.

    Binary operators are parenthesized only where the precedence table
    ``PREC`` requires it (operators are treated as left-associative).
    ``let`` expressions are always parenthesized when used as operands,
    and string literals are escaped so they stay well-formed.
    """

    PREC = {
        'or': 1, 'and': 2,
        '==': 3, '!=': 3,
        '<': 4, '>': 4, '<=': 4, '>=': 4,
        '+': 5, '-': 5,
        '*': 6, '/': 6, '%': 6,
    }

    def _prec(self, op: str) -> int:
        """Binding strength of *op*; unknown operators bind tightest (99)."""
        return self.PREC.get(op, 99)

    def visit_IntLit(self, node: IntLit) -> str:      return str(node.value)
    def visit_FloatLit(self, node: FloatLit) -> str:  return str(node.value)
    def visit_BoolLit(self, node: BoolLit) -> str:    return 'true' if node.value else 'false'
    def visit_Var(self, node: Var) -> str:            return node.name

    def visit_StrLit(self, node: StrLit) -> str:
        # Escape the backslash first, then characters that would end or
        # break a double-quoted literal (C-style escapes assumed).
        escaped = (
            node.value.replace('\\', '\\\\')
            .replace('"', '\\"')
            .replace('\n', '\\n')
        )
        return f'"{escaped}"'

    def _operand(self, child: Node, text: str, parent_prec: int, is_right: bool) -> str:
        """Parenthesize *text* (the rendering of *child*) if it must be
        grouped to remain a single operand of a binary operator."""
        if isinstance(child, LetExpr):
            # A let body extends as far right as possible, so a bare let
            # would swallow whatever follows it.
            return f"({text})"
        if isinstance(child, BinOp):
            child_prec = self._prec(child.op)
            # Equal precedence on the right needs parens too, because
            # operators associate to the left: a - (b - c).
            if child_prec < parent_prec or (is_right and child_prec == parent_prec):
                return f"({text})"
        return text

    def visit_BinOp(self, node: BinOp) -> str:
        prec = self._prec(node.op)
        left = self._operand(node.left, self.visit(node.left), prec, is_right=False)
        right = self._operand(node.right, self.visit(node.right), prec, is_right=True)
        return f"{left} {node.op} {right}"

    def visit_UnaryOp(self, node: UnaryOp) -> str:
        inner = self.visit(node.operand)
        # Group compound operands. Also group an operand that starts with
        # '-' under unary '-', so nested negation and negative literals
        # print as -(-x) rather than --x.
        if isinstance(node.operand, (BinOp, LetExpr)) or (
            node.op == '-' and inner.startswith('-')
        ):
            inner = f"({inner})"
        # Word operators such as 'not' need a space before the operand.
        sep = ' ' if node.op.isalpha() else ''
        return f"{node.op}{sep}{inner}"

    def visit_IfExpr(self, node: IfExpr) -> str:
        cond = self.visit(node.condition)
        then = self.visit(node.then_expr)
        els  = self.visit(node.else_expr)
        return f"({cond} ? {then} : {els})"

    def visit_LetExpr(self, node: LetExpr) -> str:
        val  = self.visit(node.value)
        body = self.visit(node.body)
        return f"let {node.name} = {val} in {body}"

    def visit_FuncCall(self, node: FuncCall) -> str:
        args = ', '.join(self.visit(a) for a in node.args)
        return f"{node.name}({args})"
401
402
403# ---------------------------------------------------------------------------
404# Demo
405# ---------------------------------------------------------------------------
406
def demo(label: str, ast: Node) -> None:
    """
    Print one demo case: the label, the pretty-printed form, the inferred
    type (or the type errors) and the evaluated value.

    Type checking and evaluation use fixed sample environments
    (x: int = 5, y: float = 2.0, s: str = 'hello', flag: bool = True).
    Evaluation is attempted even when type checking reported errors.
    """
    print(f"\n{'─'*52}")
    print(f"Expression: {label}")

    printed = PrettyPrintVisitor().visit(ast)
    print(f"  Pretty:    {printed}")

    tc = TypeCheckVisitor(type_env={'x': INT, 'y': FLOAT, 's': STR, 'flag': BOOL})
    inferred = tc.visit(ast)
    if tc.errors:
        print("  Type:      ERROR")
        for e in tc.errors:
            print(f"             {e}")
    else:
        print(f"  Type:      {inferred}")

    ev = EvalVisitor(env={'x': 5, 'y': 2.0, 's': 'hello', 'flag': True})
    try:
        result = ev.visit(ast)
    except (EvalError, ArithmeticError, TypeError, ValueError) as e:
        # Operators are Python's own, so ill-typed or degenerate input
        # (1 / 0, "a" - 1, int("x")) surfaces as a Python exception; report
        # it like an EvalError instead of aborting the remaining demos.
        print(f"  Eval err:  {e}")
    else:
        print(f"  Value:     {result!r}")
430
431
def main():
    """Print a banner, then run demo() on each sample expression in order."""
    banner = "=" * 60
    print(banner)
    print("AST Visitor Pattern Demo")
    print(banner)

    cases = [
        # 1. Simple arithmetic
        ("x * 2 + y",
         BinOp('+', BinOp('*', Var('x'), IntLit(2)), Var('y'))),
        # 2. Comparison combined with a logical operator
        ("x > 3 and flag",
         BinOp('and', BinOp('>', Var('x'), IntLit(3)), Var('flag'))),
        # 3. Ternary if
        ("flag ? x : -x",
         IfExpr(Var('flag'), Var('x'), UnaryOp('-', Var('x')))),
        # 4. Let expression
        ("let z = x + 1 in z * z",
         LetExpr('z', BinOp('+', Var('x'), IntLit(1)),
                 BinOp('*', Var('z'), Var('z')))),
        # 5. Function call
        ("sqrt(x * x + y * y)",
         FuncCall('sqrt', [
             BinOp('+',
                   BinOp('*', Var('x'), Var('x')),
                   BinOp('*', Var('y'), Var('y'))),
         ])),
        # 6. String concatenation
        ('s + " world"',
         BinOp('+', Var('s'), StrLit(' world'))),
        # 7. Type error: int + bool
        ("x + flag (type error)",
         BinOp('+', Var('x'), Var('flag'))),
        # 8. Nested let
        ("let a = 3 in let b = a + 2 in a * b",
         LetExpr('a', IntLit(3),
                 LetExpr('b', BinOp('+', Var('a'), IntLit(2)),
                         BinOp('*', Var('a'), Var('b'))))),
    ]

    for label, tree in cases:
        demo(label, tree)
475
476
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()