1"""
210_bytecode_vm.py - Bytecode Compiler and Stack-Based Virtual Machine
3
4Demonstrates the final phases of a compiler:
5 1. Bytecode compiler: translates an AST into a sequence of bytecode instructions
6 2. Virtual machine: executes the bytecode on a stack-based architecture
7
8Architecture overview:
9 - Stack machine: operands pushed, instructions consume/produce stack values
10 - Operand stack: holds intermediate values during expression evaluation
11 - Call stack (frames): local variables, return addresses, function arguments
  - Inline operands: literal values are embedded directly in PUSH instructions
    (this VM has no separate constant pool)
13
14Instruction set:
15 PUSH <val> Push a constant value onto the stack
16 POP Discard the top of stack
17 LOAD <name> Load a local variable onto the stack
18 STORE <name> Pop and store to a local variable
19 ADD, SUB, MUL, DIV, MOD Binary arithmetic
20 NEG Negate top of stack
21 EQ, NE, LT, GT, LE, GE Comparison (pushes 0 or 1)
22 AND, OR, NOT Logical operators
  JUMP <target>            Unconditional jump to an absolute instruction index
  JUMP_IF_FALSE <target>   Pop; jump to <target> if the value is zero/false
25 CALL <name> <argc> Call a function
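  RETURN                   Pop the return value and return from the function
  RETURN_NONE              Return from the function with the value None
  PRINT                    Pop and print (for demo output)
  DUP                      Duplicate the top of stack
  HALT                     Stop execution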
28
29Topics covered:
30 - Bytecode instruction encoding
31 - Stack-based expression evaluation
32 - Function frames with local variables
33 - Jump offsets for control flow
34 - Bytecode disassembler
35 - Recursive functions via call stack
36"""
37
38from __future__ import annotations
39from dataclasses import dataclass, field
40from enum import Enum, auto
41from typing import Any, Optional
42
43
44# ---------------------------------------------------------------------------
45# Instruction set
46# ---------------------------------------------------------------------------
47
class Op(Enum):
    """Bytecode opcodes understood by the VM.

    Values are spelled out explicitly as 1..27 in declaration order, which is
    exactly what ``auto()`` would assign for a plain ``Enum``.
    """

    # Stack and variable access
    PUSH = 1             # PUSH <value>
    POP = 2              # POP
    LOAD = 3             # LOAD <name>
    STORE = 4            # STORE <name>

    # Arithmetic
    ADD = 5
    SUB = 6
    MUL = 7
    DIV = 8
    MOD = 9
    NEG = 10

    # Comparison
    EQ = 11
    NE = 12
    LT = 13
    GT = 14
    LE = 15
    GE = 16

    # Logic
    AND = 17
    OR = 18
    NOT = 19

    # Control flow (jump targets are absolute instruction indices)
    JUMP = 20            # JUMP <target_ip>
    JUMP_IF_FALSE = 21   # JUMP_IF_FALSE <target_ip>
    CALL = 22            # CALL <func_name> <argc>
    RETURN = 23          # returns the value on top of the stack
    RETURN_NONE = 24     # returns None (void)

    # Miscellaneous
    PRINT = 25           # pop and print
    HALT = 26            # stop execution
    DUP = 27             # duplicate top of stack
76
77
@dataclass
class Instruction:
    """One bytecode instruction: an opcode plus up to two operands.

    ``arg1`` carries the pushed value (PUSH), the variable name (LOAD/STORE),
    the absolute jump target (JUMP/JUMP_IF_FALSE) or the function name (CALL);
    ``arg2`` carries the argument count for CALL.
    """
    op: Op
    arg1: Any = None  # First operand (name, value, or jump target)
    arg2: Any = None  # Second operand (e.g., argc for CALL)

    # Number of operands each opcode is encoded with, keyed by opcode name;
    # opcodes not listed take none. __repr__ uses this so that an operand whose
    # value is legitimately None (e.g. ``PUSH None``) is still displayed.
    # Deliberately unannotated so @dataclass does not treat it as a field.
    _OPERAND_COUNTS = {
        'PUSH': 1,
        'LOAD': 1,
        'STORE': 1,
        'JUMP': 1,
        'JUMP_IF_FALSE': 1,
        'CALL': 2,
    }

    def __repr__(self) -> str:
        """Render as a disassembly line: opcode name followed by its operands."""
        count = self._OPERAND_COUNTS.get(self.op.name, 0)
        # Never hide an operand that is actually set, even on an opcode that
        # is not in the table above.
        if self.arg2 is not None:
            count = 2
        elif self.arg1 is not None:
            count = max(count, 1)

        if count >= 2:
            return f"{self.op.name:<18} {self.arg1!r:<16} {self.arg2!r}"
        if count == 1:
            return f"{self.op.name:<18} {self.arg1!r}"
        return f"{self.op.name}"
91
92
93# ---------------------------------------------------------------------------
94# Compiled function object
95# ---------------------------------------------------------------------------
96
@dataclass
class Function:
    """A compiled function: its name, its parameter names, and its bytecode."""
    name: str
    params: list[str]
    code: list[Instruction] = field(default_factory=list)

    def __repr__(self):
        signature = ', '.join(self.params)
        return f"<Function {self.name}({signature})>"
105
106
107# ---------------------------------------------------------------------------
108# Call frame
109# ---------------------------------------------------------------------------
110
@dataclass
class Frame:
    """Activation record for one in-progress function call.

    Tracks the function being executed, the index of the next instruction to
    run in it, and the variables bound during this call.
    """
    func: Function
    ip: int = field(default=0)                             # next instruction index
    locals: dict[str, Any] = field(default_factory=dict)   # this call's variables
    return_value: Any = field(default=None)                # not read by the VM in this file
118
119
120# ---------------------------------------------------------------------------
121# AST node types (mini-language: same as 04_recursive_descent_parser)
122# ---------------------------------------------------------------------------
123
@dataclass
class NumLit:
    """Numeric literal."""
    value: Any


@dataclass
class StrLit:
    """String literal."""
    value: str


@dataclass
class BoolLit:
    """Boolean literal."""
    value: bool


@dataclass
class Var:
    """Reference to a variable by name."""
    name: str


@dataclass
class BinOp:
    """Binary operation ``left <op> right``; ``op`` is the operator token."""
    op: str
    left: Any
    right: Any


@dataclass
class UnaryOp:
    """Unary operation ``<op> operand``; ``op`` is the operator token."""
    op: str
    operand: Any


@dataclass
class Assign:
    """Assignment statement ``name = value``."""
    name: str
    value: Any


@dataclass
class IfStmt:
    """Conditional statement; ``else_branch`` is None when absent."""
    condition: Any
    then_branch: Any
    else_branch: Optional[Any] = None


@dataclass
class WhileStmt:
    """Loop statement: run ``body`` while ``condition`` holds."""
    condition: Any
    body: Any


@dataclass
class PrintStmt:
    """Print the value of an expression."""
    value: Any


@dataclass
class ReturnStmt:
    """Return statement; ``value`` is None for a bare ``return``."""
    value: Optional[Any] = None


@dataclass
class Block:
    """Sequence of statements."""
    stmts: list = field(default_factory=list)


@dataclass
class Program:
    """Whole program: top-level statements and function definitions."""
    stmts: list = field(default_factory=list)


@dataclass
class FuncDef:
    """Function definition with parameter names and a body statement."""
    name: str
    params: list
    body: Any


@dataclass
class CallExpr:
    """Call of the function ``name`` with argument expressions ``args``."""
    name: str
    args: list = field(default_factory=list)
154
155
156# ---------------------------------------------------------------------------
157# Bytecode Compiler
158# ---------------------------------------------------------------------------
159
160class Compiler:
161 """
162 Compiles an AST into bytecode (a list of Instructions).
163 Handles top-level statements and function definitions.
164 """
165
166 def __init__(self):
167 self.functions: dict[str, Function] = {}
168 self._current: Optional[Function] = None
169
170 def _emit(self, op: Op, arg1=None, arg2=None) -> int:
171 """Emit an instruction and return its index."""
172 instr = Instruction(op, arg1, arg2)
173 self._current.code.append(instr)
174 return len(self._current.code) - 1
175
176 def _patch(self, idx: int, target: int) -> None:
177 """Patch a jump instruction's target address."""
178 self._current.code[idx].arg1 = target
179
180 def _ip(self) -> int:
181 """Current next instruction index."""
182 return len(self._current.code)
183
184 def compile_program(self, program: Program) -> None:
185 """
186 Two-pass compilation:
187 Pass 1: collect function definitions (so forward calls work)
188 Pass 2: compile the main body and all functions
189 """
190 # Pass 1: Register function signatures
191 func_defs = []
192 main_stmts = []
193 for stmt in program.stmts:
194 if isinstance(stmt, FuncDef):
195 func_defs.append(stmt)
196 else:
197 main_stmts.append(stmt)
198
199 # Create main function
200 main_func = Function(name='__main__', params=[])
201 self.functions['__main__'] = main_func
202
203 # Register user-defined functions
204 for fd in func_defs:
205 f = Function(name=fd.name, params=fd.params)
206 self.functions[fd.name] = f
207
208 # Compile function bodies
209 for fd in func_defs:
210 self._current = self.functions[fd.name]
211 self.compile_stmt(fd.body)
212 # Implicit void return
213 self._emit(Op.PUSH, None)
214 self._emit(Op.RETURN)
215
216 # Compile main body
217 self._current = main_func
218 for stmt in main_stmts:
219 self.compile_stmt(stmt)
220 self._emit(Op.HALT)
221
222 def compile_stmt(self, node) -> None:
223 match node:
224 case Program(stmts=stmts) | Block(stmts=stmts):
225 for s in stmts:
226 self.compile_stmt(s)
227
228 case Assign(name=name, value=val):
229 self.compile_expr(val)
230 self._emit(Op.STORE, name)
231
232 case PrintStmt(value=val):
233 self.compile_expr(val)
234 self._emit(Op.PRINT)
235
236 case ReturnStmt(value=val):
237 if val is not None:
238 self.compile_expr(val)
239 self._emit(Op.RETURN)
240 else:
241 self._emit(Op.PUSH, None)
242 self._emit(Op.RETURN)
243
244 case IfStmt(condition=cond, then_branch=then_br, else_branch=else_br):
245 # Compile condition
246 self.compile_expr(cond)
247 # Emit JUMP_IF_FALSE (target patched later)
248 jif = self._emit(Op.JUMP_IF_FALSE, None)
249 # Compile then branch
250 self.compile_stmt(then_br)
251 if else_br:
252 # Jump over else branch
253 jmp = self._emit(Op.JUMP, None)
254 # Patch jif to here (else branch start)
255 self._patch(jif, self._ip())
256 self.compile_stmt(else_br)
257 # Patch jmp to here (after else)
258 self._patch(jmp, self._ip())
259 else:
260 self._patch(jif, self._ip())
261
262 case WhileStmt(condition=cond, body=body):
263 loop_start = self._ip()
264 self.compile_expr(cond)
265 jif = self._emit(Op.JUMP_IF_FALSE, None) # exit loop
266 self.compile_stmt(body)
267 self._emit(Op.JUMP, loop_start) # loop back
268 self._patch(jif, self._ip()) # patch exit
269
270 case FuncDef():
271 pass # handled in compile_program
272
273 case _:
274 # Treat as expression statement
275 self.compile_expr(node)
276 self._emit(Op.POP)
277
278 def compile_expr(self, node) -> None:
279 match node:
280 case NumLit(value=v):
281 self._emit(Op.PUSH, v)
282
283 case StrLit(value=v):
284 self._emit(Op.PUSH, v)
285
286 case BoolLit(value=v):
287 self._emit(Op.PUSH, 1 if v else 0)
288
289 case Var(name=n):
290 self._emit(Op.LOAD, n)
291
292 case BinOp(op=op, left=left, right=right):
293 # Short-circuit for && and ||
294 if op == '&&':
295 self.compile_expr(left)
296 self._emit(Op.DUP)
297 jif = self._emit(Op.JUMP_IF_FALSE, None)
298 self._emit(Op.POP)
299 self.compile_expr(right)
300 self._patch(jif, self._ip())
301 return
302 if op == '||':
303 self.compile_expr(left)
304 self._emit(Op.DUP)
305 jif_true = self._emit(Op.JUMP_IF_FALSE, None)
306 # left is truthy: jump past right
307 jmp_end = self._emit(Op.JUMP, None)
308 self._patch(jif_true, self._ip())
309 self._emit(Op.POP)
310 self.compile_expr(right)
311 self._patch(jmp_end, self._ip())
312 return
313
314 self.compile_expr(left)
315 self.compile_expr(right)
316 op_map = {
317 '+': Op.ADD, '-': Op.SUB, '*': Op.MUL, '/': Op.DIV, '%': Op.MOD,
318 '==': Op.EQ, '!=': Op.NE, '<': Op.LT, '>': Op.GT,
319 '<=': Op.LE, '>=': Op.GE,
320 '&&': Op.AND, '||': Op.OR,
321 }
322 self._emit(op_map[op])
323
324 case UnaryOp(op=op, operand=operand):
325 self.compile_expr(operand)
326 if op == '-': self._emit(Op.NEG)
327 elif op == '!': self._emit(Op.NOT)
328
329 case CallExpr(name=name, args=args):
330 for arg in args:
331 self.compile_expr(arg)
332 self._emit(Op.CALL, name, len(args))
333
334 case _:
335 raise ValueError(f"Unknown expression: {node!r}")
336
337
338# ---------------------------------------------------------------------------
339# Virtual Machine
340# ---------------------------------------------------------------------------
341
class VMError(Exception):
    """Runtime fault raised by the VM.

    Examples include stack underflow, undefined variables, division by zero,
    unknown functions, and exceeding the instruction limit.
    """
344
345
346class VM:
347 """
348 Stack-based virtual machine that executes bytecode.
349
350 Architecture:
351 - operand_stack: stack of values during expression evaluation
352 - call_stack: stack of Frame objects for function calls
353 """
354
355 def __init__(self, functions: dict[str, Function]):
356 self.functions = functions
357 self.operand_stack: list[Any] = []
358 self.call_stack: list[Frame] = []
359
360 def push(self, val: Any) -> None:
361 self.operand_stack.append(val)
362
363 def pop(self) -> Any:
364 if not self.operand_stack:
365 raise VMError("Stack underflow")
366 return self.operand_stack.pop()
367
368 def peek(self) -> Any:
369 if not self.operand_stack:
370 raise VMError("Stack empty")
371 return self.operand_stack[-1]
372
373 def run(self, func_name: str = '__main__') -> Any:
374 func = self.functions.get(func_name)
375 if func is None:
376 raise VMError(f"Function not found: {func_name!r}")
377 frame = Frame(func=func)
378 self.call_stack.append(frame)
379 return self._execute()
380
381 def _execute(self) -> Any:
382 MAX_INSTRUCTIONS = 100_000
383 count = 0
384
385 while self.call_stack:
386 frame = self.call_stack[-1]
387 if frame.ip >= len(frame.func.code):
388 # Implicit return
389 self.call_stack.pop()
390 self.push(None)
391 continue
392
393 instr = frame.func.code[frame.ip]
394 frame.ip += 1
395 count += 1
396 if count > MAX_INSTRUCTIONS:
397 raise VMError("Execution limit exceeded (infinite loop?)")
398
399 match instr.op:
400 case Op.PUSH:
401 self.push(instr.arg1)
402
403 case Op.POP:
404 self.pop()
405
406 case Op.DUP:
407 self.push(self.peek())
408
409 case Op.LOAD:
410 name = instr.arg1
411 # Walk up call stack to find variable
412 for f in reversed(self.call_stack):
413 if name in f.locals:
414 self.push(f.locals[name])
415 break
416 else:
417 raise VMError(f"Undefined variable: {name!r}")
418
419 case Op.STORE:
420 val = self.pop()
421 frame.locals[instr.arg1] = val
422
423 case Op.ADD:
424 b, a = self.pop(), self.pop()
425 self.push(a + b)
426
427 case Op.SUB:
428 b, a = self.pop(), self.pop()
429 self.push(a - b)
430
431 case Op.MUL:
432 b, a = self.pop(), self.pop()
433 self.push(a * b)
434
435 case Op.DIV:
436 b, a = self.pop(), self.pop()
437 if b == 0: raise VMError("Division by zero")
438 self.push(a // b if isinstance(a, int) and isinstance(b, int) else a / b)
439
440 case Op.MOD:
441 b, a = self.pop(), self.pop()
442 self.push(a % b)
443
444 case Op.NEG:
445 self.push(-self.pop())
446
447 case Op.NOT:
448 self.push(0 if self.pop() else 1)
449
450 case Op.EQ:
451 b, a = self.pop(), self.pop()
452 self.push(1 if a == b else 0)
453
454 case Op.NE:
455 b, a = self.pop(), self.pop()
456 self.push(1 if a != b else 0)
457
458 case Op.LT:
459 b, a = self.pop(), self.pop()
460 self.push(1 if a < b else 0)
461
462 case Op.GT:
463 b, a = self.pop(), self.pop()
464 self.push(1 if a > b else 0)
465
466 case Op.LE:
467 b, a = self.pop(), self.pop()
468 self.push(1 if a <= b else 0)
469
470 case Op.GE:
471 b, a = self.pop(), self.pop()
472 self.push(1 if a >= b else 0)
473
474 case Op.AND:
475 b, a = self.pop(), self.pop()
476 self.push(1 if (a and b) else 0)
477
478 case Op.OR:
479 b, a = self.pop(), self.pop()
480 self.push(1 if (a or b) else 0)
481
482 case Op.JUMP:
483 frame.ip = instr.arg1
484
485 case Op.JUMP_IF_FALSE:
486 cond = self.pop()
487 if not cond:
488 frame.ip = instr.arg1
489
490 case Op.CALL:
491 func_name = instr.arg1
492 argc = instr.arg2
493 # Pop arguments in reverse order
494 args = []
495 for _ in range(argc):
496 args.insert(0, self.pop())
497
498 if func_name in self.functions:
499 callee = self.functions[func_name]
500 new_frame = Frame(func=callee)
501 # Bind parameters
502 for pname, pval in zip(callee.params, args):
503 new_frame.locals[pname] = pval
504 self.call_stack.append(new_frame)
505 else:
506 # Built-in functions
507 result = self._call_builtin(func_name, args)
508 self.push(result)
509
510 case Op.RETURN:
511 retval = self.pop()
512 self.call_stack.pop()
513 self.push(retval)
514
515 case Op.RETURN_NONE:
516 self.call_stack.pop()
517 self.push(None)
518
519 case Op.PRINT:
520 val = self.pop()
521 print(f" [vm output] {val}")
522
523 case Op.HALT:
524 return self.operand_stack[-1] if self.operand_stack else None
525
526 return self.operand_stack[-1] if self.operand_stack else None
527
528 def _call_builtin(self, name: str, args: list) -> Any:
529 builtins = {
530 'abs': lambda xs: abs(xs[0]),
531 'max': lambda xs: max(xs),
532 'min': lambda xs: min(xs),
533 'str': lambda xs: str(xs[0]),
534 'int': lambda xs: int(xs[0]),
535 'float': lambda xs: float(xs[0]),
536 }
537 if name not in builtins:
538 raise VMError(f"Unknown function: {name!r}")
539 return builtins[name](args)
540
541
def disassemble(func: Function) -> None:
    """Print a readable listing of ``func``'s bytecode, one numbered line per instruction.

    The listing is followed by a blank line.
    """
    header = f"Function: {func.name}({', '.join(func.params)})"
    listing = [f" {idx:>4} {ins}" for idx, ins in enumerate(func.code)]
    print("\n".join([header, *listing]))
    print()
548
549
550# ---------------------------------------------------------------------------
551# Demo programs
552# ---------------------------------------------------------------------------
553
def demo1_arithmetic():
    """Demo 1: compile and execute ``x = 5; y = 3; print(x * y + 2)``."""
    print("--- Demo 1: Arithmetic ---")
    product = BinOp('*', Var('x'), Var('y'))
    statements = [
        Assign('x', NumLit(5)),
        Assign('y', NumLit(3)),
        PrintStmt(BinOp('+', product, NumLit(2))),
    ]
    comp = Compiler()
    comp.compile_program(Program(statements))
    print("Bytecode:")
    disassemble(comp.functions['__main__'])
    VM(comp.functions).run()
568
569
def demo2_if_else():
    """Demo 2: compile and execute an if/else that branches on ``x > 5`` with x = 7."""
    print("--- Demo 2: If/Else ---")
    condition = BinOp('>', Var('x'), NumLit(5))
    then_block = Block([PrintStmt(StrLit("x > 5"))])
    else_block = Block([PrintStmt(StrLit("x <= 5"))])
    program = Program([
        Assign('x', NumLit(7)),
        IfStmt(condition, then_block, else_block),
    ])
    comp = Compiler()
    comp.compile_program(program)
    print("Bytecode:")
    disassemble(comp.functions['__main__'])
    VM(comp.functions).run()
587
588
def demo3_while_loop():
    """Demo 3: compile and execute a while loop that sums 1..10 into ``total``."""
    print("--- Demo 3: While Loop (sum 1..10) ---")
    loop_body = Block([
        Assign('total', BinOp('+', Var('total'), Var('i'))),
        Assign('i', BinOp('+', Var('i'), NumLit(1))),
    ])
    loop = WhileStmt(BinOp('<=', Var('i'), NumLit(10)), loop_body)
    program = Program([
        Assign('i', NumLit(1)),
        Assign('total', NumLit(0)),
        loop,
        PrintStmt(Var('total')),
    ])
    comp = Compiler()
    comp.compile_program(program)
    print("Bytecode:")
    disassemble(comp.functions['__main__'])
    VM(comp.functions).run()
610
611
def demo4_function_call():
    """Demo 4: compile and execute a recursive factorial, called as ``factorial(6)``."""
    print("--- Demo 4: Recursive Factorial ---")
    # factorial(n) = if n <= 1 then 1 else n * factorial(n-1)
    n_minus_one = BinOp('-', Var('n'), NumLit(1))
    recursive_case = BinOp('*', Var('n'), CallExpr('factorial', [n_minus_one]))
    factorial_def = FuncDef(
        name='factorial',
        params=['n'],
        body=Block([
            IfStmt(
                BinOp('<=', Var('n'), NumLit(1)),
                Block([ReturnStmt(NumLit(1))]),
                Block([ReturnStmt(recursive_case)]),
            ),
        ]),
    )
    program = Program([
        factorial_def,
        Assign('result', CallExpr('factorial', [NumLit(6)])),
        PrintStmt(Var('result')),
    ])
    comp = Compiler()
    comp.compile_program(program)
    listings = (
        ("Bytecode for 'factorial':", 'factorial'),
        ("Bytecode for '__main__':", '__main__'),
    )
    for heading, fname in listings:
        print(heading)
        disassemble(comp.functions[fname])
    VM(comp.functions).run()
643
644
def main():
    """Print a banner, then run each demo in order with a blank line before each."""
    banner = "=" * 60
    print(banner)
    print("Bytecode Compiler and Stack-Based VM Demo")
    print(banner)
    demos = (demo1_arithmetic, demo2_if_else, demo3_while_loop, demo4_function_call)
    for demo in demos:
        print()
        demo()


if __name__ == "__main__":
    main()