04_recursive_descent_parser.py

Download
python 498 lines 14.2 KB
  1"""
  204_recursive_descent_parser.py - Recursive Descent Parser
  3
  4Implements a hand-written recursive descent parser for a simple
  5imperative language. Builds an AST and provides evaluation.
  6
  7Grammar (simplified, left-recursion removed):
  8  program     ::= stmt*
  9  stmt        ::= if_stmt | while_stmt | print_stmt | assign_stmt | block
 10  if_stmt     ::= 'if' '(' expr ')' stmt ('else' stmt)?
 11  while_stmt  ::= 'while' '(' expr ')' stmt
 12  print_stmt  ::= 'print' '(' expr ')' ';'
 13  assign_stmt ::= IDENT '=' expr ';'
 14  block       ::= '{' stmt* '}'
 15  expr        ::= or_expr
 16  or_expr     ::= and_expr ('||' and_expr)*
 17  and_expr    ::= eq_expr ('&&' eq_expr)*
 18  eq_expr     ::= rel_expr (('=='|'!=') rel_expr)*
 19  rel_expr    ::= add_expr (('<'|'>'|'<='|'>=') add_expr)*
 20  add_expr    ::= mul_expr (('+'|'-') mul_expr)*
 21  mul_expr    ::= unary (('*'|'/') unary)*
 22  unary       ::= ('-'|'!') unary | primary
 23  primary     ::= INT | FLOAT | STRING | BOOL | IDENT | '(' expr ')'
 24
 25Topics covered:
 26  - Recursive descent parsing
 27  - Abstract Syntax Tree (AST) construction
 28  - Operator precedence via grammar layers
 29  - Tree evaluation (interpreter)
 30  - AST pretty-printing
 31"""
 32
 33from __future__ import annotations
 34import re
 35from dataclasses import dataclass, field
 36from typing import Any, Optional
 37
 38
 39# ---------------------------------------------------------------------------
 40# Lexer (minimal, reused from 01_lexer concepts)
 41# ---------------------------------------------------------------------------
 42
 43@dataclass
 44class Token:
 45    type: str
 46    value: str
 47    line: int
 48
 49PATTERNS = [
 50    ('FLOAT',   r'\d+\.\d*'),
 51    ('INT',     r'\d+'),
 52    ('STRING',  r'"[^"]*"'),
 53    ('BOOL',    r'\b(true|false)\b'),
 54    ('KW',      r'\b(if|else|while|print)\b'),
 55    ('IDENT',   r'[A-Za-z_]\w*'),
 56    ('OP',      r'==|!=|<=|>=|&&|\|\||[+\-*/<>=!]'),
 57    ('PUNCT',   r'[(){};,]'),
 58    ('WS',      r'\s+'),
 59]
 60_LEX_RE = re.compile('|'.join(f'(?P<{name}>{pat})' for name, pat in PATTERNS))
 61
 62
 63def tokenize(source: str) -> list[Token]:
 64    tokens = []
 65    line = 1
 66    for m in _LEX_RE.finditer(source):
 67        kind = m.lastgroup
 68        val = m.group()
 69        if kind == 'WS':
 70            line += val.count('\n')
 71            continue
 72        if kind == 'KW':
 73            tokens.append(Token(val, val, line))  # keyword type IS the word
 74        elif kind == 'BOOL':
 75            tokens.append(Token('BOOL', val, line))
 76        else:
 77            tokens.append(Token(kind, val, line))
 78    tokens.append(Token('EOF', '', line))
 79    return tokens
 80
 81
 82# ---------------------------------------------------------------------------
 83# AST Node definitions
 84# ---------------------------------------------------------------------------
 85
@dataclass
class NumLit:
    """Numeric literal node."""

    value: float
    is_int: bool = True  # parser sets True for INT tokens, False for FLOAT tokens
 90
@dataclass
class StrLit:
    """String literal node."""

    value: str  # literal text; the parser stores it without the surrounding quotes
 94
@dataclass
class BoolLit:
    """Boolean literal node (``true`` / ``false`` in the source language)."""

    value: bool
 98
@dataclass
class Var:
    """Variable reference node."""

    name: str  # identifier being read
102
@dataclass
class BinOp:
    """Binary operation node."""

    op: str     # operator lexeme, e.g. '+', '<=', '&&'
    left: Any   # left operand (expression node)
    right: Any  # right operand (expression node)
108
@dataclass
class UnaryOp:
    """Prefix unary operation node."""

    op: str       # operator lexeme; the parser produces '-' or '!'
    operand: Any  # expression node the operator applies to
113
@dataclass
class Assign:
    """Assignment statement node: ``name = value;``."""

    name: str   # target variable name
    value: Any  # expression node whose result is stored
118
@dataclass
class IfStmt:
    """Conditional statement node with an optional ``else`` branch."""

    condition: Any                     # expression node tested for truthiness
    then_branch: Any                   # statement node run when the condition holds
    else_branch: Optional[Any] = None  # statement node for 'else'; None when absent
124
@dataclass
class WhileStmt:
    """Loop statement node: ``while (condition) body``."""

    condition: Any  # expression node re-evaluated before each iteration
    body: Any       # statement node executed while the condition holds
129
@dataclass
class PrintStmt:
    """Print statement node: ``print(value);``."""

    value: Any  # expression node whose result is printed
133
@dataclass
class Block:
    """Brace-delimited statement list: ``{ stmt* }``."""

    stmts: list = field(default_factory=list)  # statement nodes in source order
137
@dataclass
class Program:
    """Root AST node holding the top-level statements."""

    stmts: list = field(default_factory=list)  # statement nodes in source order
141
142
143# ---------------------------------------------------------------------------
144# Parser
145# ---------------------------------------------------------------------------
146
class ParseError(Exception):
    """Syntax error reported by the parser.

    The message combines the caller's description with the line, type and
    text of the token at which parsing failed.
    """

    def __init__(self, msg: str, token: Token):
        where = f"Parse error at line {token.line}"
        seen = f"(got {token.type!r} = {token.value!r})"
        super().__init__(f"{where}: {msg} {seen}")
150
151
class Parser:
    """Hand-written recursive descent parser that builds an AST from tokens.

    Each grammar rule has a ``parse_*`` method.  Operator precedence is
    encoded by the layering of the expression methods: ``parse_or`` (lowest)
    calls ``parse_and``, which calls ``parse_eq``, and so on down to
    ``parse_primary`` (highest).  All binary operators are left-associative.

    Syntax errors are reported by raising ``ParseError``.
    """

    def __init__(self, tokens: list[Token]):
        self.tokens = tokens
        self.pos = 0  # index of the next unconsumed token

    # --- Token stream helpers ---

    def peek(self) -> Token:
        """Return the current token without consuming it."""
        return self.tokens[self.pos]

    def advance(self) -> Token:
        """Consume the current token and return it."""
        tok = self.tokens[self.pos]
        self.pos += 1
        return tok

    def expect(self, type_: str, value: Optional[str] = None) -> Token:
        """Consume and return the current token if it has the required shape.

        Args:
            type_: required token type.
            value: required token text, or ``None`` to accept any text.

        Raises:
            ParseError: if the current token does not match.  The message
                names the specific text that was required when ``value`` is
                given (e.g. ``';'``), otherwise the token type.
        """
        tok = self.peek()
        if tok.type != type_ or (value is not None and tok.value != value):
            wanted = type_ if value is None else value
            raise ParseError(f"Expected {wanted!r}", tok)
        return self.advance()

    def match(self, type_: str, value: Optional[str] = None) -> bool:
        """Report whether the current token matches, without consuming it."""
        tok = self.peek()
        return tok.type == type_ and (value is None or tok.value == value)

    # --- Statements ---

    def parse_program(self) -> Program:
        """program ::= stmt*  (up to the EOF token)."""
        stmts = []
        while not self.match('EOF'):
            stmts.append(self.parse_stmt())
        return Program(stmts)

    def parse_stmt(self):
        """stmt ::= if_stmt | while_stmt | print_stmt | assign_stmt | block.

        The statement kind is chosen from the current token alone.
        """
        tok = self.peek()
        if tok.type == 'if':
            return self.parse_if()
        elif tok.type == 'while':
            return self.parse_while()
        elif tok.type == 'print':
            return self.parse_print()
        elif tok.type == 'PUNCT' and tok.value == '{':
            return self.parse_block()
        elif tok.type == 'IDENT':
            return self.parse_assign()
        else:
            raise ParseError("Expected statement", tok)

    def parse_if(self) -> IfStmt:
        """if_stmt ::= 'if' '(' expr ')' stmt ('else' stmt)?"""
        self.expect('if')
        self.expect('PUNCT', '(')
        cond = self.parse_expr()
        self.expect('PUNCT', ')')
        then_br = self.parse_stmt()
        else_br = None
        # An 'else' is consumed as soon as it is seen, so it binds to the
        # nearest enclosing 'if'.
        if self.match('else'):
            self.advance()
            else_br = self.parse_stmt()
        return IfStmt(cond, then_br, else_br)

    def parse_while(self) -> WhileStmt:
        """while_stmt ::= 'while' '(' expr ')' stmt"""
        self.expect('while')
        self.expect('PUNCT', '(')
        cond = self.parse_expr()
        self.expect('PUNCT', ')')
        body = self.parse_stmt()
        return WhileStmt(cond, body)

    def parse_print(self) -> PrintStmt:
        """print_stmt ::= 'print' '(' expr ')' ';'"""
        self.expect('print')
        self.expect('PUNCT', '(')
        val = self.parse_expr()
        self.expect('PUNCT', ')')
        self.expect('PUNCT', ';')
        return PrintStmt(val)

    def parse_assign(self) -> Assign:
        """assign_stmt ::= IDENT '=' expr ';'"""
        name = self.expect('IDENT').value
        self.expect('OP', '=')
        val = self.parse_expr()
        self.expect('PUNCT', ';')
        return Assign(name, val)

    def parse_block(self) -> Block:
        """block ::= '{' stmt* '}'"""
        self.expect('PUNCT', '{')
        stmts = []
        # Stop at EOF as well, so an unclosed block is reported by the
        # expect() below instead of by a failed statement parse.
        while not (self.match('PUNCT', '}') or self.match('EOF')):
            stmts.append(self.parse_stmt())
        self.expect('PUNCT', '}')
        return Block(stmts)

    # --- Expressions (one method per precedence level, lowest first) ---

    def _parse_binary_level(self, parse_operand, ops):
        """Parse ``operand (op operand)*`` for ``op`` in ``ops``, left-assoc."""
        left = parse_operand()
        while self.peek().type == 'OP' and self.peek().value in ops:
            op = self.advance().value
            left = BinOp(op, left, parse_operand())
        return left

    def parse_expr(self):
        """expr ::= or_expr"""
        return self.parse_or()

    def parse_or(self):
        """or_expr ::= and_expr ('||' and_expr)*"""
        return self._parse_binary_level(self.parse_and, ('||',))

    def parse_and(self):
        """and_expr ::= eq_expr ('&&' eq_expr)*"""
        return self._parse_binary_level(self.parse_eq, ('&&',))

    def parse_eq(self):
        """eq_expr ::= rel_expr (('=='|'!=') rel_expr)*"""
        return self._parse_binary_level(self.parse_rel, ('==', '!='))

    def parse_rel(self):
        """rel_expr ::= add_expr (('<'|'>'|'<='|'>=') add_expr)*"""
        return self._parse_binary_level(self.parse_add, ('<', '>', '<=', '>='))

    def parse_add(self):
        """add_expr ::= mul_expr (('+'|'-') mul_expr)*"""
        return self._parse_binary_level(self.parse_mul, ('+', '-'))

    def parse_mul(self):
        """mul_expr ::= unary (('*'|'/') unary)*"""
        return self._parse_binary_level(self.parse_unary, ('*', '/'))

    def parse_unary(self):
        """unary ::= ('-'|'!') unary | primary"""
        if self.peek().type == 'OP' and self.peek().value in ('-', '!'):
            op = self.advance().value
            return UnaryOp(op, self.parse_unary())
        return self.parse_primary()

    def parse_primary(self):
        """primary ::= INT | FLOAT | STRING | BOOL | IDENT | '(' expr ')'"""
        tok = self.peek()
        if tok.type == 'INT':
            self.advance()
            return NumLit(int(tok.value), is_int=True)
        elif tok.type == 'FLOAT':
            self.advance()
            return NumLit(float(tok.value), is_int=False)
        elif tok.type == 'STRING':
            self.advance()
            return StrLit(tok.value[1:-1])   # strip quotes
        elif tok.type == 'BOOL':
            self.advance()
            return BoolLit(tok.value == 'true')
        elif tok.type == 'IDENT':
            self.advance()
            return Var(tok.value)
        elif tok.type == 'PUNCT' and tok.value == '(':
            self.advance()
            expr = self.parse_expr()
            self.expect('PUNCT', ')')
            return expr
        else:
            raise ParseError("Expected primary expression", tok)
324
325
326# ---------------------------------------------------------------------------
327# Evaluator (tree-walk interpreter)
328# ---------------------------------------------------------------------------
329
class EvalError(Exception):
    """Runtime failure raised while evaluating an AST."""
332
333
334class Evaluator:
335    def __init__(self):
336        self.env: dict[str, Any] = {}
337
338    def eval(self, node) -> Any:
339        match node:
340            case Program(stmts=stmts) | Block(stmts=stmts):
341                result = None
342                for s in stmts:
343                    result = self.eval(s)
344                return result
345            case Assign(name=name, value=val):
346                self.env[name] = self.eval(val)
347                return self.env[name]
348            case IfStmt(condition=cond, then_branch=then_br, else_branch=else_br):
349                if self.eval(cond):
350                    return self.eval(then_br)
351                elif else_br:
352                    return self.eval(else_br)
353            case WhileStmt(condition=cond, body=body):
354                while self.eval(cond):
355                    self.eval(body)
356            case PrintStmt(value=val):
357                v = self.eval(val)
358                print(f"  [output] {v}")
359                return v
360            case NumLit(value=v):
361                return v
362            case StrLit(value=v):
363                return v
364            case BoolLit(value=v):
365                return v
366            case Var(name=name):
367                if name not in self.env:
368                    raise EvalError(f"Undefined variable: {name!r}")
369                return self.env[name]
370            case BinOp(op=op, left=left, right=right):
371                l, r = self.eval(left), self.eval(right)
372                match op:
373                    case '+':  return l + r
374                    case '-':  return l - r
375                    case '*':  return l * r
376                    case '/':  return l / r if r != 0 else (raise_(EvalError("Division by zero")))
377                    case '<':  return l < r
378                    case '>':  return l > r
379                    case '<=': return l <= r
380                    case '>=': return l >= r
381                    case '==': return l == r
382                    case '!=': return l != r
383                    case '&&': return bool(l) and bool(r)
384                    case '||': return bool(l) or bool(r)
385            case UnaryOp(op=op, operand=operand):
386                v = self.eval(operand)
387                if op == '-': return -v
388                if op == '!': return not v
389        return None
390
391
def raise_(exc):
    """Raise ``exc``; a function form of ``raise`` that can be called from within an expression."""
    raise exc
394
395
396# ---------------------------------------------------------------------------
397# AST Pretty Printer
398# ---------------------------------------------------------------------------
399
400def pprint(node, indent: int = 0) -> None:
401    pad = "  " * indent
402    match node:
403        case Program(stmts=stmts):
404            print(f"{pad}Program")
405            for s in stmts: pprint(s, indent+1)
406        case Block(stmts=stmts):
407            print(f"{pad}Block")
408            for s in stmts: pprint(s, indent+1)
409        case Assign(name=n, value=v):
410            print(f"{pad}Assign({n!r})")
411            pprint(v, indent+1)
412        case IfStmt(condition=c, then_branch=t, else_branch=e):
413            print(f"{pad}If")
414            print(f"{pad}  cond:"); pprint(c, indent+2)
415            print(f"{pad}  then:"); pprint(t, indent+2)
416            if e: print(f"{pad}  else:"); pprint(e, indent+2)
417        case WhileStmt(condition=c, body=b):
418            print(f"{pad}While")
419            print(f"{pad}  cond:"); pprint(c, indent+2)
420            print(f"{pad}  body:"); pprint(b, indent+2)
421        case PrintStmt(value=v):
422            print(f"{pad}Print"); pprint(v, indent+1)
423        case BinOp(op=op, left=l, right=r):
424            print(f"{pad}BinOp({op!r})")
425            pprint(l, indent+1); pprint(r, indent+1)
426        case UnaryOp(op=op, operand=o):
427            print(f"{pad}Unary({op!r})"); pprint(o, indent+1)
428        case NumLit(value=v): print(f"{pad}Num({v})")
429        case StrLit(value=v): print(f"{pad}Str({v!r})")
430        case BoolLit(value=v): print(f"{pad}Bool({v})")
431        case Var(name=n): print(f"{pad}Var({n!r})")
432        case _: print(f"{pad}{node!r}")
433
434
435# ---------------------------------------------------------------------------
436# Demo
437# ---------------------------------------------------------------------------
438
# Sample 1: straight-line assignments and arithmetic, then a print.
PROGRAM_1 = """\
x = 10;
y = 3;
z = x * y + 2;
print(z);
"""

# Sample 2: iterative Fibonacci using a while loop and a block body.
PROGRAM_2 = """\
n = 10;
fib_a = 0;
fib_b = 1;
i = 2;
while (i <= n) {
    tmp = fib_a + fib_b;
    fib_a = fib_b;
    fib_b = tmp;
    i = i + 1;
}
print(fib_b);
"""

# Sample 3: if/else with a comparison and string literals.
PROGRAM_3 = """\
x = 7;
if (x > 5) {
    print("x is greater than 5");
} else {
    print("x is not greater than 5");
}
"""
468
469
def demo(label: str, source: str) -> None:
    """Run one sample end to end: show the source, its AST and the result.

    Prints a separator and the label, echoes ``source``, tokenizes and
    parses it, pretty-prints the AST, evaluates it with a fresh
    ``Evaluator`` and finally prints the resulting variable environment.
    """
    separator = '─' * 56
    print(f"\n{separator}")
    print(f"Demo: {label}")
    print(f"Source:\n{source}")

    ast = Parser(tokenize(source)).parse_program()

    print("AST:")
    pprint(ast)

    print("\nEvaluation:")
    interpreter = Evaluator()
    interpreter.eval(ast)
    print(f"Final env: {interpreter.env}")
485
486
def main():
    """Print the banner and run every sample program through ``demo``."""
    rule = "=" * 60
    print(rule)
    print("Recursive Descent Parser Demo")
    print(rule)

    samples = (
        ("Arithmetic", PROGRAM_1),
        ("Fibonacci (while loop)", PROGRAM_2),
        ("If/Else with strings", PROGRAM_3),
    )
    for label, source in samples:
        demo(label, source)


if __name__ == "__main__":
    main()