10_bytecode_vm.py

Download
python 661 lines 21.3 KB
  1"""
  210_bytecode_vm.py - Bytecode Compiler and Stack-Based Virtual Machine
  3
  4Demonstrates the final phases of a compiler:
  5  1. Bytecode compiler: translates an AST into a sequence of bytecode instructions
  6  2. Virtual machine: executes the bytecode on a stack-based architecture
  7
  8Architecture overview:
  9  - Stack machine: operands pushed, instructions consume/produce stack values
 10  - Operand stack: holds intermediate values during expression evaluation
 11  - Call stack (frames): local variables, return addresses, function arguments
  - Inline operands: each instruction carries its own operands (e.g. PUSH holds its literal value); there is no separate constant pool
 13
 14Instruction set:
 15  PUSH <val>      Push a constant value onto the stack
 16  POP             Discard the top of stack
 17  LOAD <name>     Load a local variable onto the stack
 18  STORE <name>    Pop and store to a local variable
 19  ADD, SUB, MUL, DIV, MOD    Binary arithmetic
 20  NEG             Negate top of stack
 21  EQ, NE, LT, GT, LE, GE     Comparison (pushes 0 or 1)
 22  AND, OR, NOT    Logical operators
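  JUMP <target>   Unconditional jump to an absolute instruction index
  JUMP_IF_FALSE <target>   Pop and jump to <target> if the value is falsy
  CALL <name> <argc>  Call a function
  RETURN          Pop the return value and return from the function
  RETURN_NONE     Return None from the function
  PRINT           Pop and print (for demo output)
  DUP             Duplicate the top of stack
  HALT            Stop execution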
 28
 29Topics covered:
 30  - Bytecode instruction encoding
 31  - Stack-based expression evaluation
 32  - Function frames with local variables
  - Absolute jump targets and back-patching for control flow
 34  - Bytecode disassembler
 35  - Recursive functions via call stack
 36"""
 37
 38from __future__ import annotations
 39from dataclasses import dataclass, field
 40from enum import Enum, auto
 41from typing import Any, Optional
 42
 43
 44# ---------------------------------------------------------------------------
 45# Instruction set
 46# ---------------------------------------------------------------------------
 47
 48class Op(Enum):
 49    PUSH           = auto()   # PUSH <value>
 50    POP            = auto()   # POP
 51    LOAD           = auto()   # LOAD <name>
 52    STORE          = auto()   # STORE <name>
 53    ADD            = auto()
 54    SUB            = auto()
 55    MUL            = auto()
 56    DIV            = auto()
 57    MOD            = auto()
 58    NEG            = auto()
 59    EQ             = auto()
 60    NE             = auto()
 61    LT             = auto()
 62    GT             = auto()
 63    LE             = auto()
 64    GE             = auto()
 65    AND            = auto()
 66    OR             = auto()
 67    NOT            = auto()
 68    JUMP           = auto()   # JUMP <target_ip>  (absolute)
 69    JUMP_IF_FALSE  = auto()   # JUMP_IF_FALSE <target_ip>
 70    CALL           = auto()   # CALL <func_name> <argc>
 71    RETURN         = auto()   # RETURN (uses top of stack as return value)
 72    RETURN_NONE    = auto()   # RETURN with no value (void)
 73    PRINT          = auto()   # Pop and print
 74    HALT           = auto()   # Stop execution
 75    DUP            = auto()   # Duplicate top of stack
 76
 77
 78@dataclass
 79class Instruction:
 80    op: Op
 81    arg1: Any = None   # First operand (name, value, or jump target)
 82    arg2: Any = None   # Second operand (e.g., argc for CALL)
 83
 84    def __repr__(self) -> str:
 85        if self.arg1 is not None and self.arg2 is not None:
 86            return f"{self.op.name:<18} {self.arg1!r:<16} {self.arg2!r}"
 87        elif self.arg1 is not None:
 88            return f"{self.op.name:<18} {self.arg1!r}"
 89        else:
 90            return f"{self.op.name}"
 91
 92
 93# ---------------------------------------------------------------------------
 94# Compiled function object
 95# ---------------------------------------------------------------------------
 96
 97@dataclass
 98class Function:
 99    name: str
100    params: list[str]
101    code: list[Instruction] = field(default_factory=list)
102
103    def __repr__(self):
104        return f"<Function {self.name}({', '.join(self.params)})>"
105
106
107# ---------------------------------------------------------------------------
108# Call frame
109# ---------------------------------------------------------------------------
110
@dataclass
class Frame:
    """Activation record for one in-progress function call.

    Bundles the function being executed with its instruction pointer and
    its local-variable bindings.
    """
    func: Function                                         # code being executed
    ip: int = field(default=0)                             # index of the next instruction
    locals: dict[str, Any] = field(default_factory=dict)   # variable name -> value
    return_value: Any = field(default=None)                # NOTE(review): appears unused by the VM — verify
118
119
120# ---------------------------------------------------------------------------
121# AST node types (mini-language: same as 04_recursive_descent_parser)
122# ---------------------------------------------------------------------------
123
@dataclass
class NumLit:
    """Numeric literal."""
    value: Any


@dataclass
class StrLit:
    """String literal."""
    value: str


@dataclass
class BoolLit:
    """Boolean literal."""
    value: bool


@dataclass
class Var:
    """Reference to a variable by name."""
    name: str


@dataclass
class BinOp:
    """Binary operation; ``op`` is an operator string such as '+' or '&&'."""
    op: str
    left: Any
    right: Any


@dataclass
class UnaryOp:
    """Unary operation; ``op`` is an operator string such as '-' or '!'."""
    op: str
    operand: Any


@dataclass
class Assign:
    """Assignment of an expression's value to a named variable."""
    name: str
    value: Any


@dataclass
class IfStmt:
    """Conditional statement with an optional else branch."""
    condition: Any
    then_branch: Any
    else_branch: Optional[Any] = None


@dataclass
class WhileStmt:
    """Pre-tested loop."""
    condition: Any
    body: Any


@dataclass
class PrintStmt:
    """Print the value of an expression."""
    value: Any


@dataclass
class ReturnStmt:
    """Return from a function, optionally with a value."""
    value: Optional[Any] = None


@dataclass
class Block:
    """Sequence of statements."""
    stmts: list = field(default_factory=list)


@dataclass
class Program:
    """Top-level sequence of statements and function definitions."""
    stmts: list = field(default_factory=list)


@dataclass
class FuncDef:
    """Function definition: name, parameter names, and body statement."""
    name: str
    params: list
    body: Any


@dataclass
class CallExpr:
    """Call of a named function with argument expressions."""
    name: str
    args: list = field(default_factory=list)
154
155
156# ---------------------------------------------------------------------------
157# Bytecode Compiler
158# ---------------------------------------------------------------------------
159
160class Compiler:
161    """
162    Compiles an AST into bytecode (a list of Instructions).
163    Handles top-level statements and function definitions.
164    """
165
166    def __init__(self):
167        self.functions: dict[str, Function] = {}
168        self._current: Optional[Function] = None
169
170    def _emit(self, op: Op, arg1=None, arg2=None) -> int:
171        """Emit an instruction and return its index."""
172        instr = Instruction(op, arg1, arg2)
173        self._current.code.append(instr)
174        return len(self._current.code) - 1
175
176    def _patch(self, idx: int, target: int) -> None:
177        """Patch a jump instruction's target address."""
178        self._current.code[idx].arg1 = target
179
180    def _ip(self) -> int:
181        """Current next instruction index."""
182        return len(self._current.code)
183
184    def compile_program(self, program: Program) -> None:
185        """
186        Two-pass compilation:
187          Pass 1: collect function definitions (so forward calls work)
188          Pass 2: compile the main body and all functions
189        """
190        # Pass 1: Register function signatures
191        func_defs = []
192        main_stmts = []
193        for stmt in program.stmts:
194            if isinstance(stmt, FuncDef):
195                func_defs.append(stmt)
196            else:
197                main_stmts.append(stmt)
198
199        # Create main function
200        main_func = Function(name='__main__', params=[])
201        self.functions['__main__'] = main_func
202
203        # Register user-defined functions
204        for fd in func_defs:
205            f = Function(name=fd.name, params=fd.params)
206            self.functions[fd.name] = f
207
208        # Compile function bodies
209        for fd in func_defs:
210            self._current = self.functions[fd.name]
211            self.compile_stmt(fd.body)
212            # Implicit void return
213            self._emit(Op.PUSH, None)
214            self._emit(Op.RETURN)
215
216        # Compile main body
217        self._current = main_func
218        for stmt in main_stmts:
219            self.compile_stmt(stmt)
220        self._emit(Op.HALT)
221
222    def compile_stmt(self, node) -> None:
223        match node:
224            case Program(stmts=stmts) | Block(stmts=stmts):
225                for s in stmts:
226                    self.compile_stmt(s)
227
228            case Assign(name=name, value=val):
229                self.compile_expr(val)
230                self._emit(Op.STORE, name)
231
232            case PrintStmt(value=val):
233                self.compile_expr(val)
234                self._emit(Op.PRINT)
235
236            case ReturnStmt(value=val):
237                if val is not None:
238                    self.compile_expr(val)
239                    self._emit(Op.RETURN)
240                else:
241                    self._emit(Op.PUSH, None)
242                    self._emit(Op.RETURN)
243
244            case IfStmt(condition=cond, then_branch=then_br, else_branch=else_br):
245                # Compile condition
246                self.compile_expr(cond)
247                # Emit JUMP_IF_FALSE (target patched later)
248                jif = self._emit(Op.JUMP_IF_FALSE, None)
249                # Compile then branch
250                self.compile_stmt(then_br)
251                if else_br:
252                    # Jump over else branch
253                    jmp = self._emit(Op.JUMP, None)
254                    # Patch jif to here (else branch start)
255                    self._patch(jif, self._ip())
256                    self.compile_stmt(else_br)
257                    # Patch jmp to here (after else)
258                    self._patch(jmp, self._ip())
259                else:
260                    self._patch(jif, self._ip())
261
262            case WhileStmt(condition=cond, body=body):
263                loop_start = self._ip()
264                self.compile_expr(cond)
265                jif = self._emit(Op.JUMP_IF_FALSE, None)   # exit loop
266                self.compile_stmt(body)
267                self._emit(Op.JUMP, loop_start)             # loop back
268                self._patch(jif, self._ip())               # patch exit
269
270            case FuncDef():
271                pass   # handled in compile_program
272
273            case _:
274                # Treat as expression statement
275                self.compile_expr(node)
276                self._emit(Op.POP)
277
278    def compile_expr(self, node) -> None:
279        match node:
280            case NumLit(value=v):
281                self._emit(Op.PUSH, v)
282
283            case StrLit(value=v):
284                self._emit(Op.PUSH, v)
285
286            case BoolLit(value=v):
287                self._emit(Op.PUSH, 1 if v else 0)
288
289            case Var(name=n):
290                self._emit(Op.LOAD, n)
291
292            case BinOp(op=op, left=left, right=right):
293                # Short-circuit for && and ||
294                if op == '&&':
295                    self.compile_expr(left)
296                    self._emit(Op.DUP)
297                    jif = self._emit(Op.JUMP_IF_FALSE, None)
298                    self._emit(Op.POP)
299                    self.compile_expr(right)
300                    self._patch(jif, self._ip())
301                    return
302                if op == '||':
303                    self.compile_expr(left)
304                    self._emit(Op.DUP)
305                    jif_true = self._emit(Op.JUMP_IF_FALSE, None)
306                    # left is truthy: jump past right
307                    jmp_end = self._emit(Op.JUMP, None)
308                    self._patch(jif_true, self._ip())
309                    self._emit(Op.POP)
310                    self.compile_expr(right)
311                    self._patch(jmp_end, self._ip())
312                    return
313
314                self.compile_expr(left)
315                self.compile_expr(right)
316                op_map = {
317                    '+': Op.ADD, '-': Op.SUB, '*': Op.MUL, '/': Op.DIV, '%': Op.MOD,
318                    '==': Op.EQ, '!=': Op.NE, '<': Op.LT, '>': Op.GT,
319                    '<=': Op.LE, '>=': Op.GE,
320                    '&&': Op.AND, '||': Op.OR,
321                }
322                self._emit(op_map[op])
323
324            case UnaryOp(op=op, operand=operand):
325                self.compile_expr(operand)
326                if op == '-':  self._emit(Op.NEG)
327                elif op == '!': self._emit(Op.NOT)
328
329            case CallExpr(name=name, args=args):
330                for arg in args:
331                    self.compile_expr(arg)
332                self._emit(Op.CALL, name, len(args))
333
334            case _:
335                raise ValueError(f"Unknown expression: {node!r}")
336
337
338# ---------------------------------------------------------------------------
339# Virtual Machine
340# ---------------------------------------------------------------------------
341
class VMError(Exception):
    """Raised by the VM for runtime faults such as stack underflow or undefined variables."""
344
345
346class VM:
347    """
348    Stack-based virtual machine that executes bytecode.
349
350    Architecture:
351      - operand_stack: stack of values during expression evaluation
352      - call_stack: stack of Frame objects for function calls
353    """
354
355    def __init__(self, functions: dict[str, Function]):
356        self.functions = functions
357        self.operand_stack: list[Any] = []
358        self.call_stack: list[Frame] = []
359
360    def push(self, val: Any) -> None:
361        self.operand_stack.append(val)
362
363    def pop(self) -> Any:
364        if not self.operand_stack:
365            raise VMError("Stack underflow")
366        return self.operand_stack.pop()
367
368    def peek(self) -> Any:
369        if not self.operand_stack:
370            raise VMError("Stack empty")
371        return self.operand_stack[-1]
372
373    def run(self, func_name: str = '__main__') -> Any:
374        func = self.functions.get(func_name)
375        if func is None:
376            raise VMError(f"Function not found: {func_name!r}")
377        frame = Frame(func=func)
378        self.call_stack.append(frame)
379        return self._execute()
380
381    def _execute(self) -> Any:
382        MAX_INSTRUCTIONS = 100_000
383        count = 0
384
385        while self.call_stack:
386            frame = self.call_stack[-1]
387            if frame.ip >= len(frame.func.code):
388                # Implicit return
389                self.call_stack.pop()
390                self.push(None)
391                continue
392
393            instr = frame.func.code[frame.ip]
394            frame.ip += 1
395            count += 1
396            if count > MAX_INSTRUCTIONS:
397                raise VMError("Execution limit exceeded (infinite loop?)")
398
399            match instr.op:
400                case Op.PUSH:
401                    self.push(instr.arg1)
402
403                case Op.POP:
404                    self.pop()
405
406                case Op.DUP:
407                    self.push(self.peek())
408
409                case Op.LOAD:
410                    name = instr.arg1
411                    # Walk up call stack to find variable
412                    for f in reversed(self.call_stack):
413                        if name in f.locals:
414                            self.push(f.locals[name])
415                            break
416                    else:
417                        raise VMError(f"Undefined variable: {name!r}")
418
419                case Op.STORE:
420                    val = self.pop()
421                    frame.locals[instr.arg1] = val
422
423                case Op.ADD:
424                    b, a = self.pop(), self.pop()
425                    self.push(a + b)
426
427                case Op.SUB:
428                    b, a = self.pop(), self.pop()
429                    self.push(a - b)
430
431                case Op.MUL:
432                    b, a = self.pop(), self.pop()
433                    self.push(a * b)
434
435                case Op.DIV:
436                    b, a = self.pop(), self.pop()
437                    if b == 0: raise VMError("Division by zero")
438                    self.push(a // b if isinstance(a, int) and isinstance(b, int) else a / b)
439
440                case Op.MOD:
441                    b, a = self.pop(), self.pop()
442                    self.push(a % b)
443
444                case Op.NEG:
445                    self.push(-self.pop())
446
447                case Op.NOT:
448                    self.push(0 if self.pop() else 1)
449
450                case Op.EQ:
451                    b, a = self.pop(), self.pop()
452                    self.push(1 if a == b else 0)
453
454                case Op.NE:
455                    b, a = self.pop(), self.pop()
456                    self.push(1 if a != b else 0)
457
458                case Op.LT:
459                    b, a = self.pop(), self.pop()
460                    self.push(1 if a < b else 0)
461
462                case Op.GT:
463                    b, a = self.pop(), self.pop()
464                    self.push(1 if a > b else 0)
465
466                case Op.LE:
467                    b, a = self.pop(), self.pop()
468                    self.push(1 if a <= b else 0)
469
470                case Op.GE:
471                    b, a = self.pop(), self.pop()
472                    self.push(1 if a >= b else 0)
473
474                case Op.AND:
475                    b, a = self.pop(), self.pop()
476                    self.push(1 if (a and b) else 0)
477
478                case Op.OR:
479                    b, a = self.pop(), self.pop()
480                    self.push(1 if (a or b) else 0)
481
482                case Op.JUMP:
483                    frame.ip = instr.arg1
484
485                case Op.JUMP_IF_FALSE:
486                    cond = self.pop()
487                    if not cond:
488                        frame.ip = instr.arg1
489
490                case Op.CALL:
491                    func_name = instr.arg1
492                    argc = instr.arg2
493                    # Pop arguments in reverse order
494                    args = []
495                    for _ in range(argc):
496                        args.insert(0, self.pop())
497
498                    if func_name in self.functions:
499                        callee = self.functions[func_name]
500                        new_frame = Frame(func=callee)
501                        # Bind parameters
502                        for pname, pval in zip(callee.params, args):
503                            new_frame.locals[pname] = pval
504                        self.call_stack.append(new_frame)
505                    else:
506                        # Built-in functions
507                        result = self._call_builtin(func_name, args)
508                        self.push(result)
509
510                case Op.RETURN:
511                    retval = self.pop()
512                    self.call_stack.pop()
513                    self.push(retval)
514
515                case Op.RETURN_NONE:
516                    self.call_stack.pop()
517                    self.push(None)
518
519                case Op.PRINT:
520                    val = self.pop()
521                    print(f"  [vm output] {val}")
522
523                case Op.HALT:
524                    return self.operand_stack[-1] if self.operand_stack else None
525
526        return self.operand_stack[-1] if self.operand_stack else None
527
528    def _call_builtin(self, name: str, args: list) -> Any:
529        builtins = {
530            'abs':   lambda xs: abs(xs[0]),
531            'max':   lambda xs: max(xs),
532            'min':   lambda xs: min(xs),
533            'str':   lambda xs: str(xs[0]),
534            'int':   lambda xs: int(xs[0]),
535            'float': lambda xs: float(xs[0]),
536        }
537        if name not in builtins:
538            raise VMError(f"Unknown function: {name!r}")
539        return builtins[name](args)
540
541
def disassemble(func: Function) -> None:
    """Print a function's bytecode: a signature header, one indexed line per
    instruction, then a blank separator line."""
    rows = [f"Function: {func.name}({', '.join(func.params)})"]
    rows.extend(f"  {idx:>4}  {ins}" for idx, ins in enumerate(func.code))
    rows.append("")  # trailing blank line after the listing
    print("\n".join(rows))
548
549
550# ---------------------------------------------------------------------------
551# Demo programs
552# ---------------------------------------------------------------------------
553
def demo1_arithmetic():
    """Compile, disassemble and run straight-line arithmetic.

    Source equivalent: x = 5; y = 3; print(x * y + 2)
    """
    print("--- Demo 1: Arithmetic ---")
    product = BinOp('*', Var('x'), Var('y'))
    statements = [
        Assign('x', NumLit(5)),
        Assign('y', NumLit(3)),
        PrintStmt(BinOp('+', product, NumLit(2))),
    ]
    comp = Compiler()
    comp.compile_program(Program(statements))
    print("Bytecode:")
    disassemble(comp.functions['__main__'])
    VM(comp.functions).run()
568
569
def demo2_if_else():
    """Compile, disassemble and run an if/else on x = 7 (condition x > 5)."""
    print("--- Demo 2: If/Else ---")
    condition = BinOp('>', Var('x'), NumLit(5))
    then_block = Block([PrintStmt(StrLit("x > 5"))])
    else_block = Block([PrintStmt(StrLit("x <= 5"))])
    program = Program([
        Assign('x', NumLit(7)),
        IfStmt(condition, then_block, else_block),
    ])
    comp = Compiler()
    comp.compile_program(program)
    print("Bytecode:")
    disassemble(comp.functions['__main__'])
    VM(comp.functions).run()
587
588
def demo3_while_loop():
    """Compile, disassemble and run a while loop that sums 1..10."""
    print("--- Demo 3: While Loop (sum 1..10) ---")
    loop_body = Block([
        Assign('total', BinOp('+', Var('total'), Var('i'))),
        Assign('i', BinOp('+', Var('i'), NumLit(1))),
    ])
    loop = WhileStmt(BinOp('<=', Var('i'), NumLit(10)), loop_body)
    program = Program([
        Assign('i', NumLit(1)),
        Assign('total', NumLit(0)),
        loop,
        PrintStmt(Var('total')),
    ])
    comp = Compiler()
    comp.compile_program(program)
    print("Bytecode:")
    disassemble(comp.functions['__main__'])
    VM(comp.functions).run()
610
611
def demo4_function_call():
    """Compile, disassemble and run a recursive factorial, called with 6.

    factorial(n) = if n <= 1 then 1 else n * factorial(n-1)
    """
    print("--- Demo 4: Recursive Factorial ---")
    base_case = Block([ReturnStmt(NumLit(1))])
    recursive_call = CallExpr('factorial', [BinOp('-', Var('n'), NumLit(1))])
    recursive_case = Block([ReturnStmt(BinOp('*', Var('n'), recursive_call))])
    factorial_def = FuncDef(
        name='factorial',
        params=['n'],
        body=Block([IfStmt(BinOp('<=', Var('n'), NumLit(1)), base_case, recursive_case)]),
    )
    program = Program([
        factorial_def,
        Assign('result', CallExpr('factorial', [NumLit(6)])),
        PrintStmt(Var('result')),
    ])
    comp = Compiler()
    comp.compile_program(program)
    listings = (
        ("Bytecode for 'factorial':", 'factorial'),
        ("Bytecode for '__main__':", '__main__'),
    )
    for heading, fname in listings:
        print(heading)
        disassemble(comp.functions[fname])
    VM(comp.functions).run()
643
644
def main():
    """Print a banner, then run each demo preceded by a blank line."""
    banner = "=" * 60
    print(banner)
    print("Bytecode Compiler and Stack-Based VM Demo")
    print(banner)
    demos = (demo1_arithmetic, demo2_if_else, demo3_while_loop, demo4_function_call)
    for demo in demos:
        print()
        demo()
657
658
# Run all demos only when executed as a script, not when imported.
if __name__ == "__main__":
    main()