07_type_checker.py

Download
python 627 lines 19.6 KB
  1"""
  207_type_checker.py - Symbol Table and Type Checker
  3
  4Demonstrates semantic analysis: symbol table management with nested
  5scopes, and type checking for a mini-language.
  6
  7Mini-language features:
  8  - Types: int, float, bool, string, void
  9  - Variables: declaration and use
 10  - Functions: declaration with typed parameters and return type
 11  - Control flow: if/else, while
 12  - Expressions: arithmetic, comparison, logical
 13  - Return statements
 14
 15The type checker:
 16  1. Builds a symbol table with nested scopes
 17  2. Checks that variables are declared before use
 18  3. Checks that assignments are type-compatible
 19  4. Checks function call arity and argument types
 20  5. Checks that return types match function signatures
 21  6. Reports all errors (not just the first)
 22
 23Topics covered:
 24  - Scope chain (linked symbol tables)
 25  - Type compatibility (int/float coercion)
 26  - Function type signatures
 27  - Two-pass: collect function declarations, then type-check bodies
 28  - Accumulating multiple type errors
 29"""
 30
 31from __future__ import annotations
 32from dataclasses import dataclass, field
 33from typing import Any, Optional
 34
 35
 36# ---------------------------------------------------------------------------
 37# Type system
 38# ---------------------------------------------------------------------------
 39
 40@dataclass(frozen=True)
 41class Type:
 42    name: str
 43
 44    def __repr__(self) -> str:
 45        return self.name
 46
 47
 48INT    = Type('int')
 49FLOAT  = Type('float')
 50BOOL   = Type('bool')
 51STRING = Type('string')
 52VOID   = Type('void')
 53ERROR  = Type('error')    # sentinel for type errors
 54
 55TYPE_MAP: dict[str, Type] = {
 56    'int': INT, 'float': FLOAT, 'bool': BOOL,
 57    'string': STRING, 'void': VOID,
 58}
 59
 60
 61def is_numeric(t: Type) -> bool:
 62    return t in (INT, FLOAT)
 63
 64
 65def numeric_result(t1: Type, t2: Type) -> Type:
 66    """int op int -> int; anything with float -> float."""
 67    if t1 == ERROR or t2 == ERROR:
 68        return ERROR
 69    if t1 == FLOAT or t2 == FLOAT:
 70        return FLOAT
 71    return INT
 72
 73
 74def compatible(expected: Type, actual: Type) -> bool:
 75    """Is 'actual' assignable to 'expected'?"""
 76    if expected == actual:
 77        return True
 78    # int is implicitly convertible to float
 79    if expected == FLOAT and actual == INT:
 80        return True
 81    return False
 82
 83
 84# ---------------------------------------------------------------------------
 85# Symbol Table
 86# ---------------------------------------------------------------------------
 87
 88@dataclass
 89class Symbol:
 90    name: str
 91    type: Type
 92    is_function: bool = False
 93    param_types: list[Type] = field(default_factory=list)
 94    return_type: Optional[Type] = None
 95    defined_at: int = 0    # line number
 96
 97
 98class Scope:
 99    """A single scope level (function body, block, etc.)."""
100
101    def __init__(self, name: str, parent: Optional[Scope] = None):
102        self.name = name
103        self.parent = parent
104        self._symbols: dict[str, Symbol] = {}
105
106    def define(self, sym: Symbol) -> bool:
107        """Define a symbol. Returns False if already defined in this scope."""
108        if sym.name in self._symbols:
109            return False
110        self._symbols[sym.name] = sym
111        return True
112
113    def lookup(self, name: str) -> Optional[Symbol]:
114        """Look up a name in this scope and all parent scopes."""
115        if name in self._symbols:
116            return self._symbols[name]
117        if self.parent:
118            return self.parent.lookup(name)
119        return None
120
121    def lookup_local(self, name: str) -> Optional[Symbol]:
122        """Look up only in this scope (not parents)."""
123        return self._symbols.get(name)
124
125    def depth(self) -> int:
126        d = 0
127        s = self
128        while s.parent:
129            d += 1
130            s = s.parent
131        return d
132
133    def __repr__(self) -> str:
134        syms = list(self._symbols.keys())
135        return f"Scope({self.name!r}, symbols={syms})"
136
137
class SymbolTable:
    """
    Tracks the chain of nested scopes during a walk of the AST.

    ``global_scope`` is the root of the chain and ``current`` is the innermost
    scope.  Definitions go into ``current``; lookups search outward from it.
    """

    def __init__(self):
        self.global_scope = Scope("global")
        self.current: Scope = self.global_scope

    def enter_scope(self, name: str) -> Scope:
        """Create a child of the current scope, make it current, and return it."""
        self.current = Scope(name, parent=self.current)
        return self.current

    def exit_scope(self) -> Scope:
        """Leave the current scope and return it.

        Exiting the root scope is a no-op: it remains current and is returned.
        """
        leaving = self.current
        outer = leaving.parent
        if outer:
            self.current = outer
        return leaving

    def define(self, sym: Symbol) -> bool:
        """Define `sym` in the current scope; False if already defined there."""
        scope = self.current
        return scope.define(sym)

    def lookup(self, name: str) -> Optional[Symbol]:
        """Resolve `name` starting from the current scope."""
        scope = self.current
        return scope.lookup(name)

    def depth(self) -> int:
        """Nesting depth of the current scope (0 at global level)."""
        scope = self.current
        return scope.depth()
167
168
169# ---------------------------------------------------------------------------
170# AST Nodes (minimal, focus on type checking)
171# ---------------------------------------------------------------------------
172
@dataclass
class Program:
    """Root of the AST: the top-level declarations in source order."""

    declarations: list


@dataclass
class FuncDecl:
    """Function definition.

    Types are source-level names such as ``'int'``.  ``params`` holds
    ``(name, type_name)`` pairs and ``body`` the function's statements.
    """

    name: str
    return_type: str
    params: list[tuple[str, str]]
    body: list
    line: int = 0  # source line; 0 = unknown


@dataclass
class VarDecl:
    """Variable declaration with an optional initializer expression."""

    name: str
    type_str: str
    init: Optional[Any] = None  # None means "declared without initializer"
    line: int = 0  # source line; 0 = unknown


@dataclass
class Assign:
    """Assignment of an expression to an existing variable."""

    name: str
    value: Any
    line: int = 0  # source line; 0 = unknown


@dataclass
class ReturnStmt:
    """``return`` statement; ``value`` is None for a bare ``return``."""

    value: Optional[Any] = None
    line: int = 0  # source line; 0 = unknown


@dataclass
class IfStmt:
    """``if``/``else`` statement; ``else_block`` is None when there is no else."""

    condition: Any
    then_block: list
    else_block: Optional[list] = None
    line: int = 0  # source line; 0 = unknown


@dataclass
class WhileStmt:
    """``while`` loop."""

    condition: Any
    body: list
    line: int = 0  # source line; 0 = unknown


@dataclass
class ExprStmt:
    """An expression evaluated for its effect, e.g. a call on its own line."""

    expr: Any
    line: int = 0  # source line; 0 = unknown


@dataclass
class IntLit:
    """Integer literal."""

    value: int


@dataclass
class FloatLit:
    """Floating-point literal."""

    value: float


@dataclass
class BoolLit:
    """Boolean literal."""

    value: bool


@dataclass
class StrLit:
    """String literal."""

    value: str


@dataclass
class VarRef:
    """Reference to a named variable."""

    name: str
    line: int = 0  # source line; 0 = unknown


@dataclass
class BinOp:
    """Binary operation ``left op right``; ``op`` is the operator spelling, e.g. '+'."""

    op: str
    left: Any
    right: Any
    line: int = 0  # source line; 0 = unknown


@dataclass
class UnaryOp:
    """Prefix unary operation ``op operand``."""

    op: str
    operand: Any
    line: int = 0  # source line; 0 = unknown


@dataclass
class CallExpr:
    """Call of the function ``name`` with positional argument expressions."""

    name: str
    args: list
    line: int = 0  # source line; 0 = unknown
260
261
262# ---------------------------------------------------------------------------
263# Type Checker
264# ---------------------------------------------------------------------------
265
class TypeChecker:
    """
    Walks the AST and performs type checking.

    Errors are accumulated in ``self.errors`` (a list of strings) rather than
    raised, so one run reports every problem it can find.  An ill-typed
    expression evaluates to the ERROR sentinel, and checks involving ERROR
    are skipped so that one mistake does not cascade into follow-up errors.
    """

    def __init__(self):
        self.table = SymbolTable()
        self.errors: list[str] = []
        # Declared return type of the function whose body is being checked;
        # None while checking top-level code (used to reject a stray 'return').
        self._current_func_return: Optional[Type] = None

    def error(self, msg: str, line: int = 0) -> Type:
        """Record `msg` (tagged with `line` when non-zero) and return ERROR.

        Returning ERROR lets expression checkers write ``return self.error(...)``.
        """
        loc = f" [line {line}]" if line else ""
        self.errors.append(f"TypeError{loc}: {msg}")
        return ERROR

    # --- Top-level ---

    def check_program(self, program: Program) -> None:
        """Type-check a whole program in two passes.

        Pass 1 registers every function signature in the global scope, so a
        call may precede the callee's definition.  Pass 2 checks every
        top-level declaration in source order.
        """
        for decl in program.declarations:
            if isinstance(decl, FuncDecl):
                self._register_func(decl)

        for decl in program.declarations:
            self.check_decl(decl)

    def _register_func(self, decl: FuncDecl) -> None:
        """Define `decl`'s signature in the current (global) scope.

        Unknown type names and ``void`` parameters are recorded as ERROR so
        that calls do not produce follow-up errors.  They are reported once,
        by check_func, when the body is checked.
        """
        ret_type = TYPE_MAP.get(decl.return_type, ERROR)
        param_types = [self._param_type(pt) for _, pt in decl.params]
        sym = Symbol(
            name=decl.name,
            type=ret_type,
            is_function=True,
            param_types=param_types,
            return_type=ret_type,
            defined_at=decl.line,
        )
        if not self.table.define(sym):
            self.error(f"Function {decl.name!r} already defined", decl.line)

    @staticmethod
    def _param_type(type_str: str) -> Type:
        """Map a parameter type name to a Type; unknown names and 'void' map to ERROR."""
        ptype = TYPE_MAP.get(type_str, ERROR)
        return ERROR if ptype == VOID else ptype

    def check_decl(self, decl) -> None:
        """Check one top-level declaration."""
        if isinstance(decl, FuncDecl):
            self.check_func(decl)
        elif isinstance(decl, VarDecl):
            self.check_var_decl(decl, global_scope=True)
        else:
            # Check any other top-level node as a statement instead of silently
            # ignoring it; this is what makes "'return' outside function" reachable.
            self.check_stmt(decl)

    def check_func(self, decl: FuncDecl) -> None:
        """Check a function body in a new scope that holds its parameters."""
        ret_type = TYPE_MAP.get(decl.return_type)
        if ret_type is None:
            # Report unknown return types, just as unknown parameter and
            # variable types are reported, instead of silently using ERROR.
            self.error(f"Unknown return type {decl.return_type!r}", decl.line)
            ret_type = ERROR

        outer_return = self._current_func_return
        self._current_func_return = ret_type
        self.table.enter_scope(f"func:{decl.name}")
        try:
            for pname, ptype_str in decl.params:
                ptype = TYPE_MAP.get(ptype_str, ERROR)
                if ptype == ERROR:
                    self.error(f"Unknown parameter type {ptype_str!r}", decl.line)
                elif ptype == VOID:
                    self.error(f"Parameter {pname!r} cannot have type void", decl.line)
                    ptype = ERROR
                sym = Symbol(name=pname, type=ptype, defined_at=decl.line)
                if not self.table.define(sym):
                    self.error(f"Duplicate parameter {pname!r}", decl.line)

            for stmt in decl.body:
                self.check_stmt(stmt)
        finally:
            # Keep the scope chain and return context consistent even if a check raises.
            self.table.exit_scope()
            self._current_func_return = outer_return

    def check_var_decl(self, decl: VarDecl, global_scope: bool = False) -> None:
        """Check a variable declaration and define the variable in the current scope.

        The initializer is checked before the name is defined.  A use of the
        name inside its own initializer therefore resolves to an outer binding,
        or is reported as undefined.  `global_scope` is accepted for backward
        compatibility but unused: the symbol always goes into the current scope.
        """
        declared_type = TYPE_MAP.get(decl.type_str)
        if declared_type is None:
            self.error(f"Unknown type {decl.type_str!r}", decl.line)
            declared_type = ERROR
        elif declared_type == VOID:
            self.error(f"Variable {decl.name!r} cannot have type void", decl.line)
            declared_type = ERROR

        if decl.init is not None:
            init_type = self.check_expr(decl.init)
            if (init_type != ERROR and declared_type != ERROR
                    and not compatible(declared_type, init_type)):
                self.error(
                    f"Cannot initialize {declared_type} variable with {init_type} value",
                    decl.line
                )

        sym = Symbol(name=decl.name, type=declared_type, defined_at=decl.line)
        if not self.table.define(sym):
            self.error(f"Variable {decl.name!r} already declared in this scope", decl.line)

    # --- Statements ---

    def check_stmt(self, stmt) -> None:
        """Dispatch on statement type; unknown node types are reported, not ignored."""
        if isinstance(stmt, VarDecl):
            self.check_var_decl(stmt)
        elif isinstance(stmt, Assign):
            self.check_assign(stmt)
        elif isinstance(stmt, ReturnStmt):
            self.check_return(stmt)
        elif isinstance(stmt, IfStmt):
            self.check_if(stmt)
        elif isinstance(stmt, WhileStmt):
            self.check_while(stmt)
        elif isinstance(stmt, ExprStmt):
            self.check_expr(stmt.expr)
        else:
            self.error(
                f"Unknown statement type: {type(stmt).__name__}",
                getattr(stmt, "line", 0),
            )

    def _check_block(self, stmts: list, scope_name: str) -> None:
        """Check `stmts` inside a fresh nested scope named `scope_name`."""
        self.table.enter_scope(scope_name)
        try:
            for s in stmts:
                self.check_stmt(s)
        finally:
            self.table.exit_scope()

    def check_assign(self, stmt: Assign) -> None:
        """Check that the target is a declared variable and the value fits its type."""
        sym = self.table.lookup(stmt.name)
        if sym is None:
            self.error(f"Undefined variable {stmt.name!r}", stmt.line)
        elif sym.is_function:
            self.error(f"Cannot assign to function {stmt.name!r}", stmt.line)

        # Always check the right-hand side so errors inside it are reported
        # even when the assignment target itself is invalid.
        val_type = self.check_expr(stmt.value)
        if sym is None or sym.is_function:
            return
        if val_type != ERROR and sym.type != ERROR and not compatible(sym.type, val_type):
            self.error(
                f"Cannot assign {val_type} to variable {stmt.name!r} of type {sym.type}",
                stmt.line
            )

    def check_return(self, stmt: ReturnStmt) -> None:
        """Check a return against the enclosing function's declared return type."""
        expected = self._current_func_return
        if expected is None:
            self.error("'return' outside function", stmt.line)
            return
        if stmt.value is None:
            # A bare 'return' is valid only in a void function.  Skip the check
            # when the declared return type was itself invalid (ERROR).
            if expected not in (VOID, ERROR):
                self.error(f"Function must return {expected}, got void", stmt.line)
            return
        val_type = self.check_expr(stmt.value)
        if val_type != ERROR and expected != ERROR and not compatible(expected, val_type):
            self.error(
                f"Return type mismatch: expected {expected}, got {val_type}",
                stmt.line
            )

    def check_if(self, stmt: IfStmt) -> None:
        """Check a bool condition, then each branch in its own scope."""
        cond_type = self.check_expr(stmt.condition)
        if cond_type not in (BOOL, ERROR):
            self.error(f"If condition must be bool, got {cond_type}", stmt.line)
        self._check_block(stmt.then_block, "if-then")
        if stmt.else_block is not None:
            self._check_block(stmt.else_block, "if-else")

    def check_while(self, stmt: WhileStmt) -> None:
        """Check a bool condition, then the loop body in its own scope."""
        cond_type = self.check_expr(stmt.condition)
        if cond_type not in (BOOL, ERROR):
            self.error(f"While condition must be bool, got {cond_type}", stmt.line)
        self._check_block(stmt.body, "while-body")

    # --- Expressions ---

    def check_expr(self, expr) -> Type:
        """Return the static type of `expr`, recording errors; ERROR if ill-typed."""
        if isinstance(expr, IntLit):   return INT
        if isinstance(expr, FloatLit): return FLOAT
        if isinstance(expr, BoolLit):  return BOOL
        if isinstance(expr, StrLit):   return STRING
        if isinstance(expr, VarRef):
            sym = self.table.lookup(expr.name)
            if sym is None:
                return self.error(f"Undefined variable {expr.name!r}", expr.line)
            if sym.is_function:
                # A function's Symbol.type is its return type; it must not be
                # mistaken for the type of a variable.
                return self.error(f"{expr.name!r} is a function, not a variable", expr.line)
            return sym.type
        if isinstance(expr, BinOp):
            return self.check_binop(expr)
        if isinstance(expr, UnaryOp):
            return self.check_unaryop(expr)
        if isinstance(expr, CallExpr):
            return self.check_call(expr)
        return self.error(
            f"Unknown expression type: {type(expr).__name__}",
            getattr(expr, "line", 0),
        )

    def check_binop(self, expr: BinOp) -> Type:
        """Type a binary operation; operand errors (ERROR) suppress further reports."""
        lt = self.check_expr(expr.left)
        rt = self.check_expr(expr.right)
        op = expr.op
        if op in ('+', '-', '*', '/'):
            if op == '+' and lt == STRING and rt == STRING:
                return STRING
            if is_numeric(lt) and is_numeric(rt):
                return numeric_result(lt, rt)
            if lt != ERROR and rt != ERROR:
                self.error(f"Operator {op!r} not valid for {lt} and {rt}", expr.line)
            return ERROR
        if op in ('<', '>', '<=', '>='):
            if is_numeric(lt) and is_numeric(rt):
                return BOOL
            if lt != ERROR and rt != ERROR:
                self.error(f"Comparison {op!r} not valid for {lt} and {rt}", expr.line)
            return ERROR
        if op in ('==', '!='):
            # Same-typed values (or any two numerics) compare, but void results
            # of calls are not values and cannot be compared.
            if (lt == rt and lt != VOID) or (is_numeric(lt) and is_numeric(rt)):
                return BOOL
            if lt != ERROR and rt != ERROR:
                self.error(f"Cannot compare {lt} with {rt}", expr.line)
            return ERROR
        if op in ('&&', '||', 'and', 'or'):
            if lt == BOOL and rt == BOOL:
                return BOOL
            if lt != ERROR and rt != ERROR:
                self.error(f"Logical {op!r} requires bool operands", expr.line)
            return ERROR
        self.error(f"Unknown operator {op!r}", expr.line)
        return ERROR

    def check_unaryop(self, expr: UnaryOp) -> Type:
        """Type a unary operation: '-' on numerics, '!'/'not' on bool."""
        t = self.check_expr(expr.operand)
        if expr.op == '-':
            if is_numeric(t): return t
            if t != ERROR: self.error(f"Unary '-' not valid for {t}", expr.line)
            return ERROR
        if expr.op in ('!', 'not'):
            if t == BOOL: return BOOL
            if t != ERROR: self.error(f"'not' requires bool, got {t}", expr.line)
            return ERROR
        self.error(f"Unknown unary operator {expr.op!r}", expr.line)
        return ERROR

    def check_call(self, expr: CallExpr) -> Type:
        """Check a call's callee, arity and argument types; return its result type.

        Every argument expression is checked, including surplus arguments and
        arguments to an invalid callee, so that errors nested inside them are
        always reported.
        """
        sym = self.table.lookup(expr.name)
        if sym is None or not sym.is_function:
            if sym is None:
                self.error(f"Undefined function {expr.name!r}", expr.line)
            else:
                self.error(f"{expr.name!r} is not a function", expr.line)
            for arg in expr.args:
                self.check_expr(arg)
            return ERROR

        params = sym.param_types
        if len(expr.args) != len(params):
            self.error(
                f"Function {expr.name!r} expects {len(params)} args, got {len(expr.args)}",
                expr.line
            )
        for i, arg in enumerate(expr.args):
            at = self.check_expr(arg)
            if i >= len(params):
                continue  # surplus argument: already covered by the arity error
            expected = params[i]
            if at != ERROR and expected != ERROR and not compatible(expected, at):
                self.error(
                    f"Argument {i+1} of {expr.name!r}: expected {expected}, got {at}",
                    expr.line
                )
        return sym.return_type or VOID
507
508
509# ---------------------------------------------------------------------------
510# Demo
511# ---------------------------------------------------------------------------
512
def build_program() -> Program:
    """
    Build the AST for the demo program below.

    The numbers on the left are the source line numbers attached to the AST
    nodes, so they match the ``[line N]`` tags in the reported errors.

         1  int add(int x, int y) {
         2      return x + y;
         3  }
         4  float average(int a, int b) {
         5      return (a + b) / 2.0;
         6  }
         7  void main() {
         8      int result = add(3, 4);
         9      float avg = average(10, 20);
        10      bool flag = result > 5;
        11      string msg = "done";
        12      if (flag) {
        13          int local = result * 2;
        14      } else {
        15          int local = 0;  // same name, different scope: ok
        16      }
        17      // Error: assigning float to int
        18      result = avg;
        19      // Error: wrong arg count
        20      int bad = add(1, 2, 3);
        21  }

    Two errors are expected: the float-to-int assignment on line 18 and the
    arity mismatch on line 20.
    """
    add_func = FuncDecl(
        name='add', return_type='int',
        params=[('x', 'int'), ('y', 'int')],
        body=[ReturnStmt(
            BinOp('+', VarRef('x', line=2), VarRef('y', line=2), line=2),
            line=2
        )],
        line=1
    )
    avg_func = FuncDecl(
        name='average', return_type='float',
        params=[('a', 'int'), ('b', 'int')],
        body=[ReturnStmt(
            BinOp('/',
                  BinOp('+', VarRef('a', line=5), VarRef('b', line=5), line=5),
                  FloatLit(2.0),
                  line=5),
            line=5
        )],
        line=4
    )
    main_func = FuncDecl(
        name='main', return_type='void',
        params=[],
        body=[
            VarDecl('result', 'int',
                    CallExpr('add', [IntLit(3), IntLit(4)], line=8), line=8),
            VarDecl('avg', 'float',
                    CallExpr('average', [IntLit(10), IntLit(20)], line=9), line=9),
            VarDecl('flag', 'bool',
                    BinOp('>', VarRef('result', line=10), IntLit(5), line=10), line=10),
            VarDecl('msg', 'string', StrLit('done'), line=11),
            IfStmt(
                condition=VarRef('flag', line=12),
                then_block=[VarDecl('local', 'int',
                                    BinOp('*', VarRef('result', line=13), IntLit(2), line=13),
                                    line=13)],
                # Same name as in the then-branch, but a separate scope: not a redeclaration.
                else_block=[VarDecl('local', 'int', IntLit(0), line=15)],
                line=12
            ),
            # Intentional type error: assigning float to int
            Assign('result', VarRef('avg', line=18), line=18),
            # Intentional arity error
            VarDecl('bad', 'int',
                    CallExpr('add', [IntLit(1), IntLit(2), IntLit(3)], line=20), line=20),
        ],
        line=7
    )
    return Program(declarations=[add_func, avg_func, main_func])
577
578
def main():
    """Run the checker on the demo program and on a small well-typed one, printing the results."""
    banner = "=" * 60
    print(banner)
    print("Symbol Table and Type Checker Demo")
    print(banner)

    checker = TypeChecker()
    checker.check_program(build_program())

    print("\nGlobal scope symbols:")
    for sym_name, entry in checker.table.global_scope._symbols.items():
        if not entry.is_function:
            print(f"  variable {sym_name}: {entry.type}")
            continue
        signature = ', '.join(map(str, entry.param_types))
        print(f"  function {sym_name}({signature}) -> {entry.return_type}")

    found = checker.errors
    print("\nType checking complete.")
    print(f"Errors found: {len(found)}")
    if not found:
        print("  No errors.")
    else:
        print("\nError list:")
        for message in found:
            print(f"  {message}")

    # A small well-typed program: the checker should report nothing.
    print("\n--- Correct program (no errors) ---")
    square_decl = FuncDecl(
        'square', 'int', [('n', 'int')],
        [ReturnStmt(BinOp('*', VarRef('n'), VarRef('n')), line=1)],
        line=1,
    )
    main_decl = FuncDecl(
        'main', 'void', [],
        [
            VarDecl('x', 'int', IntLit(5), line=3),
            VarDecl('sq', 'int', CallExpr('square', [VarRef('x')], line=4), line=4),
            VarDecl('ok', 'bool', BinOp('>', VarRef('sq'), IntLit(10), line=5), line=5),
        ],
        line=2,
    )
    second = TypeChecker()
    second.check_program(Program([square_decl, main_decl]))
    print(f"Errors: {len(second.errors)}")
    if not second.errors:
        print("  No errors. Program is type-correct.")
    else:
        for message in second.errors:
            print(f"  {message}")


if __name__ == "__main__":
    main()