1"""
207_type_checker.py - Symbol Table and Type Checker
3
4Demonstrates semantic analysis: symbol table management with nested
5scopes, and type checking for a mini-language.
6
7Mini-language features:
8 - Types: int, float, bool, string, void
9 - Variables: declaration and use
10 - Functions: declaration with typed parameters and return type
11 - Control flow: if/else, while
12 - Expressions: arithmetic, comparison, logical
13 - Return statements
14
15The type checker:
16 1. Builds a symbol table with nested scopes
17 2. Checks that variables are declared before use
18 3. Checks that assignments are type-compatible
19 4. Checks function call arity and argument types
20 5. Checks that return types match function signatures
21 6. Reports all errors (not just the first)
22
23Topics covered:
24 - Scope chain (linked symbol tables)
25 - Type compatibility (int/float coercion)
26 - Function type signatures
27 - Two-pass: collect function declarations, then type-check bodies
28 - Accumulating multiple type errors
29"""
30
31from __future__ import annotations
32from dataclasses import dataclass, field
33from typing import Any, Optional
34
35
36# ---------------------------------------------------------------------------
37# Type system
38# ---------------------------------------------------------------------------
39
@dataclass(frozen=True, repr=False)
class Type:
    """An immutable, hashable primitive type identified solely by its name.

    Equality and hashing are derived from ``name``, so two ``Type`` objects
    with the same name are interchangeable.
    """

    name: str

    def __repr__(self) -> str:
        # Render just the bare name (e.g. ``int``) so messages read naturally.
        return self.name
46
47
# Canonical instances for each primitive type of the mini-language.
INT = Type('int')
FLOAT = Type('float')
BOOL = Type('bool')
STRING = Type('string')
VOID = Type('void')
# Sentinel, not a source-level type: marks an expression that already failed
# to type-check.
ERROR = Type('error')

# Source-level type name -> Type. ERROR is deliberately left out, so it can
# never be named in a program.
TYPE_MAP: dict[str, Type] = {
    prim.name: prim for prim in (INT, FLOAT, BOOL, STRING, VOID)
}
59
60
def is_numeric(t: Type) -> bool:
    """Return True when *t* is one of the arithmetic types (``int`` or ``float``)."""
    return any(numeric == t for numeric in (INT, FLOAT))
63
64
def numeric_result(t1: Type, t2: Type) -> Type:
    """Compute the result type of a binary arithmetic operation.

    ERROR on either side propagates unchanged. Otherwise float dominates:
    int op int -> int, and any combination involving float -> float.
    """
    operands = (t1, t2)
    if ERROR in operands:
        return ERROR
    return FLOAT if FLOAT in operands else INT
72
73
def compatible(expected: Type, actual: Type) -> bool:
    """Report whether a value of type *actual* may be used where *expected* is required.

    Identical types always match. The only implicit conversion is the
    widening of ``int`` to ``float``.
    """
    if expected != actual:
        return expected == FLOAT and actual == INT
    return True
82
83
84# ---------------------------------------------------------------------------
85# Symbol Table
86# ---------------------------------------------------------------------------
87
@dataclass
class Symbol:
    """A named entry in a Scope: either a variable/parameter or a function.

    For function symbols, the type checker stores the declared return type in
    both ``type`` and ``return_type``, and the ordered parameter types in
    ``param_types``.
    """

    name: str
    type: Type                      # variable type (return type for functions)
    is_function: bool = False       # True only for function symbols
    param_types: list[Type] = field(default_factory=list)  # per-instance list
    return_type: Optional[Type] = None  # filled in for function symbols
    defined_at: int = 0             # source line of the definition
96
97
class Scope:
    """One level of lexical scope (global, function body, block, ...).

    A scope maps names to Symbols and links to its enclosing scope through
    ``parent``. Name resolution walks that chain outward.
    """

    def __init__(self, name: str, parent: Optional[Scope] = None):
        self.name = name
        self.parent = parent
        self._symbols: dict[str, Symbol] = {}

    def define(self, sym: Symbol) -> bool:
        """Add *sym* to this scope.

        Returns True on success. Returns False, leaving the existing entry
        untouched, when this very scope already holds a symbol with that
        name. Shadowing a name from an enclosing scope is allowed.
        """
        if sym.name not in self._symbols:
            self._symbols[sym.name] = sym
            return True
        return False

    def lookup(self, name: str) -> Optional[Symbol]:
        """Resolve *name* starting here and moving outward; None if never defined."""
        scope = self
        while scope:
            if name in scope._symbols:
                return scope._symbols[name]
            scope = scope.parent
        return None

    def lookup_local(self, name: str) -> Optional[Symbol]:
        """Resolve *name* in this scope only, ignoring enclosing scopes."""
        return self._symbols.get(name, None)

    def depth(self) -> int:
        """Count the enclosing scopes above this one (0 for a root scope)."""
        levels = 0
        ancestor = self.parent
        while ancestor:
            levels += 1
            ancestor = ancestor.parent
        return levels

    def __repr__(self) -> str:
        return f"Scope({self.name!r}, symbols={list(self._symbols)})"
136
137
class SymbolTable:
    """Tracks the chain of nested scopes during an AST traversal.

    ``global_scope`` is the root of the chain. ``current`` is the innermost
    active scope, which receives new definitions and is where lookups start.
    """

    def __init__(self):
        root = Scope("global")
        self.global_scope = root
        self.current: Scope = root

    def enter_scope(self, name: str) -> Scope:
        """Open a child of the current scope, make it current and return it."""
        self.current = Scope(name, parent=self.current)
        return self.current

    def exit_scope(self) -> Scope:
        """Close the current scope and return it.

        The root scope has no parent, so exiting it returns it without
        changing ``current``.
        """
        leaving = self.current
        if leaving.parent:
            self.current = leaving.parent
        return leaving

    def define(self, sym: Symbol) -> bool:
        """Define *sym* in the current scope; False if the name is already taken there."""
        return self.current.define(sym)

    def lookup(self, name: str) -> Optional[Symbol]:
        """Resolve *name* from the current scope outward."""
        return self.current.lookup(name)

    def depth(self) -> int:
        """Nesting depth of the current scope (0 at global level)."""
        return self.current.depth()
167
168
169# ---------------------------------------------------------------------------
170# AST Nodes (minimal, focus on type checking)
171# ---------------------------------------------------------------------------
172
@dataclass
class Program:
    """Root AST node: the top-level declarations of a program, in source order."""

    declarations: list  # FuncDecl / VarDecl nodes
176
@dataclass
class FuncDecl:
    """Function definition: name, declared return type, typed parameters and body."""

    name: str
    return_type: str                # return type name as written, e.g. 'int'
    params: list[tuple[str, str]]   # ordered (param_name, type_name) pairs
    body: list                      # statement nodes
    line: int = 0                   # source line; 0 means "not recorded"
184
@dataclass
class VarDecl:
    """Variable declaration with an optional initializer expression."""

    name: str
    type_str: str               # declared type name as written, e.g. 'float'
    init: Optional[Any] = None  # initializer expression, or None if absent
    line: int = 0               # source line; 0 means "not recorded"
191
@dataclass
class Assign:
    """Assignment statement ``name = value`` to an existing variable."""

    name: str
    value: Any      # right-hand-side expression node
    line: int = 0   # source line; 0 means "not recorded"
197
@dataclass
class ReturnStmt:
    """``return`` statement; ``value`` is None for a bare ``return;``."""

    value: Optional[Any] = None  # returned expression node, if any
    line: int = 0                # source line; 0 means "not recorded"
202
@dataclass
class IfStmt:
    """``if``/``else`` statement; ``else_block`` is None when there is no else branch."""

    condition: Any                      # condition expression node
    then_block: list                    # statements of the then-branch
    else_block: Optional[list] = None   # statements of the else-branch, if any
    line: int = 0                       # source line; 0 means "not recorded"
209
@dataclass
class WhileStmt:
    """``while`` loop: a condition expression and a body of statements."""

    condition: Any  # loop condition expression node
    body: list      # statements executed per iteration
    line: int = 0   # source line; 0 means "not recorded"
215
@dataclass
class ExprStmt:
    """An expression evaluated as a statement (e.g. a bare function call)."""

    expr: Any       # the wrapped expression node
    line: int = 0   # source line; 0 means "not recorded"
220
@dataclass
class IntLit:
    """Integer literal; always has type ``int``."""

    value: int  # the literal's numeric value
224
@dataclass
class FloatLit:
    """Floating-point literal; always has type ``float``."""

    value: float  # the literal's numeric value
228
@dataclass
class BoolLit:
    """Boolean literal; always has type ``bool``."""

    value: bool  # True or False
232
@dataclass
class StrLit:
    """String literal; always has type ``string``."""

    value: str  # the literal's text
236
@dataclass
class VarRef:
    """A use of a variable by name inside an expression."""

    name: str       # referenced identifier
    line: int = 0   # source line; 0 means "not recorded"
241
@dataclass
class BinOp:
    """Binary operation ``left op right``."""

    op: str         # operator token, e.g. '+', '<=', '&&'
    left: Any       # left operand expression node
    right: Any      # right operand expression node
    line: int = 0   # source line; 0 means "not recorded"
248
@dataclass
class UnaryOp:
    """Prefix unary operation ``op operand``."""

    op: str         # operator token, e.g. '-', '!', 'not'
    operand: Any    # operand expression node
    line: int = 0   # source line; 0 means "not recorded"
254
@dataclass
class CallExpr:
    """Function call ``name(args...)``."""

    name: str       # callee identifier
    args: list      # argument expression nodes, in order
    line: int = 0   # source line; 0 means "not recorded"
260
261
262# ---------------------------------------------------------------------------
263# Type Checker
264# ---------------------------------------------------------------------------
265
class TypeChecker:
    """
    Walks the AST and performs type checking.

    Errors are accumulated in ``self.errors`` (list of strings) instead of
    being raised, so a single run reports every problem it finds. An
    ill-typed sub-expression evaluates to the ERROR sentinel, and checks that
    see ERROR stay silent so one mistake does not cascade into many messages.
    """

    def __init__(self):
        self.table = SymbolTable()
        self.errors: list[str] = []
        # Declared return type of the function being checked; None at top level.
        self._current_func_return: Optional[Type] = None

    def error(self, msg: str, line: int = 0) -> Type:
        """Record *msg* (prefixed with its line when line > 0) and return ERROR."""
        loc = f" [line {line}]" if line else ""
        self.errors.append(f"TypeError{loc}: {msg}")
        return ERROR

    # --- Top-level ---

    def check_program(self, program: Program) -> None:
        """Type-check a whole program in two passes.

        Pass 1 registers every function signature, so calls may appear before
        the callee's definition. Pass 2 checks each declaration in order.
        """
        for decl in program.declarations:
            if isinstance(decl, FuncDecl):
                self._register_func(decl)

        for decl in program.declarations:
            self.check_decl(decl)

    @staticmethod
    def _param_type(type_str: str) -> Type:
        """Map a parameter type name to a Type; unknown names and ``void`` give ERROR."""
        ptype = TYPE_MAP.get(type_str, ERROR)
        return ERROR if ptype == VOID else ptype

    def _register_func(self, decl: FuncDecl) -> None:
        """Define *decl*'s signature in the current scope.

        Invalid return or parameter types become ERROR here without a message.
        check_func reports them, so each problem is reported exactly once.
        """
        ret_type = TYPE_MAP.get(decl.return_type, ERROR)
        sym = Symbol(
            name=decl.name,
            type=ret_type,
            is_function=True,
            param_types=[self._param_type(pt) for _, pt in decl.params],
            return_type=ret_type,
            defined_at=decl.line,
        )
        if not self.table.define(sym):
            self.error(f"Function {decl.name!r} already defined", decl.line)

    def check_decl(self, decl) -> None:
        """Dispatch a top-level declaration; unknown node types are reported."""
        if isinstance(decl, FuncDecl):
            self.check_func(decl)
        elif isinstance(decl, VarDecl):
            self.check_var_decl(decl, global_scope=True)
        else:
            self.error(
                f"Unknown declaration type: {type(decl).__name__}",
                getattr(decl, "line", 0),
            )

    def check_func(self, decl: FuncDecl) -> None:
        """Check one function: its signature types and its body, in a new scope."""
        ret_type = TYPE_MAP.get(decl.return_type, ERROR)
        if ret_type == ERROR:
            self.error(
                f"Unknown return type {decl.return_type!r} for function {decl.name!r}",
                decl.line,
            )
        self._current_func_return = ret_type

        self.table.enter_scope(f"func:{decl.name}")
        try:
            # Define parameters in the function scope.
            for pname, ptype_str in decl.params:
                raw = TYPE_MAP.get(ptype_str)
                if raw is None:
                    self.error(f"Unknown parameter type {ptype_str!r}", decl.line)
                elif raw == VOID:
                    self.error(f"Parameter {pname!r} cannot have type void", decl.line)
                sym = Symbol(name=pname, type=self._param_type(ptype_str), defined_at=decl.line)
                if not self.table.define(sym):
                    self.error(f"Duplicate parameter {pname!r}", decl.line)

            for stmt in decl.body:
                self.check_stmt(stmt)
        finally:
            # Restore state even if a check raises, so the scope chain stays balanced.
            self.table.exit_scope()
            self._current_func_return = None

    def check_var_decl(self, decl: VarDecl, global_scope: bool = False) -> None:
        """Check a variable declaration and define the variable in the current scope.

        ``global_scope`` is currently unused. The symbol always goes into the
        current scope.
        """
        declared_type = TYPE_MAP.get(decl.type_str)
        if declared_type is None:
            self.error(f"Unknown type {decl.type_str!r}", decl.line)
            declared_type = ERROR
        elif declared_type == VOID:
            self.error(f"Variable {decl.name!r} cannot have type void", decl.line)
            declared_type = ERROR

        # Check the initializer before defining the name, so the new variable
        # is not yet visible inside its own initializer.
        if decl.init is not None:
            init_type = self.check_expr(decl.init)
            if init_type != ERROR and declared_type != ERROR:
                if not compatible(declared_type, init_type):
                    self.error(
                        f"Cannot initialize {declared_type} variable with {init_type} value",
                        decl.line
                    )

        sym = Symbol(name=decl.name, type=declared_type, defined_at=decl.line)
        if not self.table.define(sym):
            self.error(f"Variable {decl.name!r} already declared in this scope", decl.line)

    def check_stmt(self, stmt) -> None:
        """Dispatch a statement to its checker; unknown node types are reported."""
        if isinstance(stmt, VarDecl):
            self.check_var_decl(stmt)
        elif isinstance(stmt, Assign):
            self.check_assign(stmt)
        elif isinstance(stmt, ReturnStmt):
            self.check_return(stmt)
        elif isinstance(stmt, IfStmt):
            self.check_if(stmt)
        elif isinstance(stmt, WhileStmt):
            self.check_while(stmt)
        elif isinstance(stmt, ExprStmt):
            self.check_expr(stmt.expr)
        else:
            self.error(
                f"Unknown statement type: {type(stmt).__name__}",
                getattr(stmt, "line", 0),
            )

    def check_assign(self, stmt: Assign) -> None:
        """Check ``name = value``.

        The target must be a declared variable (not a function), and the
        value's type must be compatible with it.
        """
        sym = self.table.lookup(stmt.name)
        if sym is None:
            self.error(f"Undefined variable {stmt.name!r}", stmt.line)
        elif sym.is_function:
            self.error(f"Cannot assign to function {stmt.name!r}", stmt.line)
        # Always check the right-hand side so errors inside it are reported too.
        val_type = self.check_expr(stmt.value)
        if sym is None or sym.is_function:
            return
        if val_type != ERROR and sym.type != ERROR:
            if not compatible(sym.type, val_type):
                self.error(
                    f"Cannot assign {val_type} to variable {stmt.name!r} of type {sym.type}",
                    stmt.line
                )

    def check_return(self, stmt: ReturnStmt) -> None:
        """Check a return statement against the enclosing function's return type."""
        expected = self._current_func_return
        if expected is None:
            self.error("'return' outside function", stmt.line)
            if stmt.value is not None:
                self.check_expr(stmt.value)
            return
        if stmt.value is None:
            # An ERROR return type was already reported by check_func.
            if expected not in (VOID, ERROR):
                self.error(
                    f"Function must return {expected}, got void",
                    stmt.line
                )
            return
        val_type = self.check_expr(stmt.value)
        if val_type != ERROR and expected != ERROR:
            if not compatible(expected, val_type):
                self.error(
                    f"Return type mismatch: expected {expected}, got {val_type}",
                    stmt.line
                )

    def _check_block(self, stmts: list, scope_name: str) -> None:
        """Check *stmts* inside a fresh nested scope named *scope_name*."""
        self.table.enter_scope(scope_name)
        try:
            for s in stmts:
                self.check_stmt(s)
        finally:
            self.table.exit_scope()

    def check_if(self, stmt: IfStmt) -> None:
        """Check an if/else: the condition must be bool; each branch gets its own scope."""
        cond_type = self.check_expr(stmt.condition)
        if cond_type not in (BOOL, ERROR):
            self.error(f"If condition must be bool, got {cond_type}", stmt.line)
        self._check_block(stmt.then_block, "if-then")
        if stmt.else_block is not None:
            self._check_block(stmt.else_block, "if-else")

    def check_while(self, stmt: WhileStmt) -> None:
        """Check a while loop: the condition must be bool; the body gets its own scope."""
        cond_type = self.check_expr(stmt.condition)
        if cond_type not in (BOOL, ERROR):
            self.error(f"While condition must be bool, got {cond_type}", stmt.line)
        self._check_block(stmt.body, "while-body")

    def check_expr(self, expr) -> Type:
        """Return the type of *expr*, reporting any errors; ERROR when ill-typed."""
        if isinstance(expr, IntLit):
            return INT
        if isinstance(expr, FloatLit):
            return FLOAT
        if isinstance(expr, BoolLit):
            return BOOL
        if isinstance(expr, StrLit):
            return STRING
        if isinstance(expr, VarRef):
            sym = self.table.lookup(expr.name)
            if sym is None:
                return self.error(f"Undefined variable {expr.name!r}", expr.line)
            if sym.is_function:
                # The type system has no function types, so a function name
                # cannot be used as a value.
                return self.error(f"{expr.name!r} is a function, not a variable", expr.line)
            return sym.type
        if isinstance(expr, BinOp):
            return self.check_binop(expr)
        if isinstance(expr, UnaryOp):
            return self.check_unaryop(expr)
        if isinstance(expr, CallExpr):
            return self.check_call(expr)
        return self.error(f"Unknown expression type: {type(expr).__name__}")

    def check_binop(self, expr: BinOp) -> Type:
        """Type a binary operation.

        Arithmetic, comparison, equality and logical operators are handled.
        '+' also concatenates two strings.
        """
        lt = self.check_expr(expr.left)
        rt = self.check_expr(expr.right)
        op = expr.op
        if op in ('+', '-', '*', '/'):
            if op == '+' and lt == STRING and rt == STRING:
                return STRING
            if is_numeric(lt) and is_numeric(rt):
                return numeric_result(lt, rt)
            if lt != ERROR and rt != ERROR:
                self.error(f"Operator {op!r} not valid for {lt} and {rt}", expr.line)
            return ERROR
        if op in ('<', '>', '<=', '>='):
            if is_numeric(lt) and is_numeric(rt):
                return BOOL
            if lt != ERROR and rt != ERROR:
                self.error(f"Comparison {op!r} not valid for {lt} and {rt}", expr.line)
            return ERROR
        if op in ('==', '!='):
            # void has no values, so it can never be compared.
            if VOID not in (lt, rt) and (lt == rt or (is_numeric(lt) and is_numeric(rt))):
                return BOOL
            if lt != ERROR and rt != ERROR:
                self.error(f"Cannot compare {lt} with {rt}", expr.line)
            return ERROR
        if op in ('&&', '||', 'and', 'or'):
            if lt == BOOL and rt == BOOL:
                return BOOL
            if lt != ERROR and rt != ERROR:
                self.error(f"Logical {op!r} requires bool operands", expr.line)
            return ERROR
        self.error(f"Unknown operator {op!r}", expr.line)
        return ERROR

    def check_unaryop(self, expr: UnaryOp) -> Type:
        """Type a unary operation: '-' on numbers, '!'/'not' on bools."""
        t = self.check_expr(expr.operand)
        if expr.op == '-':
            if is_numeric(t):
                return t
            if t != ERROR:
                self.error(f"Unary '-' not valid for {t}", expr.line)
            return ERROR
        if expr.op in ('!', 'not'):
            if t == BOOL:
                return BOOL
            if t != ERROR:
                self.error(f"'not' requires bool, got {t}", expr.line)
            return ERROR
        self.error(f"Unknown unary operator {expr.op!r}", expr.line)
        return ERROR

    def check_call(self, expr: CallExpr) -> Type:
        """Type a call.

        The callee must be a function, and arity and argument types must
        match its signature. Every argument expression is checked, even when
        the callee is unknown or receives too many arguments, so errors nested
        inside arguments are still reported.
        """
        sym = self.table.lookup(expr.name)
        if sym is None or not sym.is_function:
            if sym is None:
                self.error(f"Undefined function {expr.name!r}", expr.line)
            else:
                self.error(f"{expr.name!r} is not a function", expr.line)
            for arg in expr.args:
                self.check_expr(arg)
            return ERROR

        params = sym.param_types
        if len(expr.args) != len(params):
            self.error(
                f"Function {expr.name!r} expects {len(params)} args, got {len(expr.args)}",
                expr.line
            )
        for i, arg in enumerate(expr.args):
            at = self.check_expr(arg)
            if i >= len(params):
                continue  # surplus argument: already covered by the arity error
            expected = params[i]
            if at != ERROR and expected != ERROR and not compatible(expected, at):
                self.error(
                    f"Argument {i+1} of {expr.name!r}: expected {expected}, got {at}",
                    expr.line
                )
        return sym.return_type or VOID
507
508
509# ---------------------------------------------------------------------------
510# Demo
511# ---------------------------------------------------------------------------
512
def build_program() -> Program:
    """Return the AST of the demo program checked by ``main``.

    Equivalent source (the ``line=`` values on the nodes follow this layout):

        int add(int x, int y) { return x + y; }
        float average(int a, int b) { return (a + b) / 2.0; }
        void main() {
            int result = add(3, 4);
            float avg = average(10, 20);
            bool flag = result > 5;
            string msg = "done";
            if (flag) {
                int local = result * 2;
            } else {
                int local = 0;   // same name, different scope: ok
            }
            result = avg;            // error: float assigned to int
            int bad = add(1, 2, 3);  // error: wrong argument count
        }
    """
    # int add(int x, int y)
    add_body = [ReturnStmt(BinOp('+', VarRef('x'), VarRef('y')), line=2)]
    add_func = FuncDecl('add', 'int', [('x', 'int'), ('y', 'int')], add_body, line=1)

    # float average(int a, int b)
    a_plus_b = BinOp('+', VarRef('a'), VarRef('b'))
    avg_body = [ReturnStmt(BinOp('/', a_plus_b, FloatLit(2.0)), line=5)]
    avg_func = FuncDecl('average', 'float', [('a', 'int'), ('b', 'int')], avg_body, line=4)

    # void main(): the if/else declares `local` in two sibling scopes.
    branch = IfStmt(
        condition=VarRef('flag'),
        then_block=[
            VarDecl('local', 'int', BinOp('*', VarRef('result'), IntLit(2)), line=13),
        ],
        else_block=[VarDecl('local', 'int', IntLit(0), line=15)],
        line=12,
    )
    main_body = [
        VarDecl('result', 'int', CallExpr('add', [IntLit(3), IntLit(4)], line=8), line=8),
        VarDecl('avg', 'float', CallExpr('average', [IntLit(10), IntLit(20)], line=9), line=9),
        VarDecl('flag', 'bool', BinOp('>', VarRef('result'), IntLit(5), line=10), line=10),
        VarDecl('msg', 'string', StrLit('done'), line=11),
        branch,
        # Planted type error: float value assigned to an int variable.
        Assign('result', VarRef('avg'), line=18),
        # Planted arity error: add() takes two arguments.
        VarDecl('bad', 'int',
                CallExpr('add', [IntLit(1), IntLit(2), IntLit(3)], line=20), line=20),
    ]
    main_func = FuncDecl('main', 'void', [], main_body, line=7)

    return Program([add_func, avg_func, main_func])
577
578
def main():
    """Run the type checker on the demo program and on a correct one, printing the results."""
    rule = "=" * 60
    print(rule)
    print("Symbol Table and Type Checker Demo")
    print(rule)

    checker = TypeChecker()
    checker.check_program(build_program())

    print("\nGlobal scope symbols:")
    for sym_name, sym in checker.table.global_scope._symbols.items():
        if not sym.is_function:
            print(f" variable {sym_name}: {sym.type}")
            continue
        signature = ', '.join(map(str, sym.param_types))
        print(f" function {sym_name}({signature}) -> {sym.return_type}")

    print("\nType checking complete.")
    print(f"Errors found: {len(checker.errors)}")
    if not checker.errors:
        print(" No errors.")
    else:
        print("\nError list:")
        for err in checker.errors:
            print(f" {err}")

    # A well-typed program should produce no errors at all.
    print("\n--- Correct program (no errors) ---")
    square = FuncDecl(
        'square', 'int', [('n', 'int')],
        [ReturnStmt(BinOp('*', VarRef('n'), VarRef('n')), line=1)],
        line=1,
    )
    entry = FuncDecl(
        'main', 'void', [],
        [
            VarDecl('x', 'int', IntLit(5), line=3),
            VarDecl('sq', 'int', CallExpr('square', [VarRef('x')], line=4), line=4),
            VarDecl('ok', 'bool', BinOp('>', VarRef('sq'), IntLit(10), line=5), line=5),
        ],
        line=2,
    )
    second = TypeChecker()
    second.check_program(Program([square, entry]))
    print(f"Errors: {len(second.errors)}")
    if not second.errors:
        print(" No errors. Program is type-correct.")
    else:
        for e in second.errors:
            print(f" {e}")
623
624
if __name__ == '__main__':
    # Run the demo only when executed as a script, never on import.
    main()