1"""
209_optimizer.py - Local Optimizations on Three-Address Code
3
4Demonstrates classic local (basic-block-level) optimizations:
5
6 1. Constant Folding
7 Evaluate expressions with all-constant operands at compile time.
8 e.g., t1 = 3 + 4 --> t1 = 7
9
10 2. Constant Propagation
11 Replace uses of a variable with its known constant value.
12 e.g., x = 5; t1 = x + 2 --> x = 5; t1 = 5 + 2 --> t1 = 7
13
14 3. Algebraic Simplification
15 Apply algebraic identities to simplify expressions.
16 e.g., t = x * 1 --> t = x
17 t = x + 0 --> t = x
18 t = x * 0 --> t = 0
19 t = x - x --> t = 0
20
21 4. Common Subexpression Elimination (CSE)
22 Avoid recomputing the same expression twice within a block.
23 e.g., t1 = a + b; t2 = a + b --> t1 = a + b; t2 = t1
24
25 5. Dead Code Elimination
26 Remove assignments to temporaries that are never used.
27 e.g., t1 = 5 (never read) --> (removed)
28
29Topics covered:
30 - Value numbering for CSE
31 - Dataflow information within a basic block
32 - Iterative optimization (apply passes until no changes)
33 - Counting how many optimizations each pass performed
34"""
35
36from __future__ import annotations
37from dataclasses import dataclass, field
38from typing import Any, Optional, Union
39
40
41# ---------------------------------------------------------------------------
42# TAC instruction types (simplified from 08_three_address_code.py)
43# ---------------------------------------------------------------------------
44
@dataclass
class Assign:
    """Copy instruction: store ``src`` (a literal or a variable name) in ``dest``."""

    dest: str
    src: Any  # literal value or variable name

    def __str__(self):
        # Rendered in TAC listing form.
        return f" {self.dest} = {self.src}"

    def copy(self):
        """Return a new, equal ``Assign`` instruction."""
        return Assign(dest=self.dest, src=self.src)
55
56
@dataclass
class BinOp:
    """Binary instruction: ``dest = left op right``."""

    dest: str
    left: Any
    op: str
    right: Any

    def __str__(self):
        # Rendered in TAC listing form.
        return f" {self.dest} = {self.left} {self.op} {self.right}"

    def copy(self):
        """Return a new, equal ``BinOp`` instruction."""
        return BinOp(
            dest=self.dest,
            left=self.left,
            op=self.op,
            right=self.right,
        )
69
70
@dataclass
class UnaryOp:
    """Unary instruction: ``dest = op src``."""

    dest: str
    op: str
    src: Any

    def __str__(self):
        # The operator is printed directly against its operand.
        return f" {self.dest} = {self.op}{self.src}"

    def copy(self):
        """Return a new, equal ``UnaryOp`` instruction."""
        return UnaryOp(dest=self.dest, op=self.op, src=self.src)
82
83
@dataclass
class Label:
    """Jump target marker, printed as ``name:``."""

    name: str

    def __str__(self):
        return f"{self.name}:"

    def copy(self):
        """Return a new, equal ``Label``."""
        return Label(name=self.name)
93
94
@dataclass
class Jump:
    """Unconditional branch to the label named ``target``."""

    target: str

    def __str__(self):
        return f" goto {self.target}"

    def copy(self):
        """Return a new, equal ``Jump``."""
        return Jump(target=self.target)
104
105
@dataclass
class CondJump:
    """Conditional branch on ``condition`` with an optional explicit else-target."""

    condition: Any
    true_target: str
    false_target: Optional[str] = None

    def __str__(self):
        head = f" if {self.condition} goto {self.true_target}"
        # A missing (or empty) false target means there is no "else" clause to print.
        if not self.false_target:
            return head
        return head + f" else goto {self.false_target}"

    def copy(self):
        """Return a new, equal ``CondJump``."""
        return CondJump(
            condition=self.condition,
            true_target=self.true_target,
            false_target=self.false_target,
        )
120
121
@dataclass
class Return:
    """Return instruction, with or without a value."""

    value: Optional[Any] = None

    def __str__(self):
        # Only ``None`` means "no value"; 0 and other falsy values are still printed.
        if self.value is None:
            return " return"
        return f" return {self.value}"

    def copy(self):
        """Return a new, equal ``Return``."""
        return Return(value=self.value)
131
132
# Type alias covering every instruction kind defined above; used in the
# signatures of the optimization passes and helpers below.
TACInstr = Union[Assign, BinOp, UnaryOp, Label, Jump, CondJump, Return]
134
135
136# ---------------------------------------------------------------------------
137# Utility: is a value a compile-time constant?
138# ---------------------------------------------------------------------------
139
def is_const(v: Any) -> bool:
    """
    Return True if ``v`` is a compile-time constant.

    Accepted forms are Python numbers (``int``, ``float``, and ``bool``, which
    is an ``int`` subclass) and strings spelling a decimal literal: an optional
    single leading ``-``, decimal digits, and at most one ``.``.
    """
    if isinstance(v, (int, float)):
        return True
    if not isinstance(v, str):
        return False
    # Strip at most ONE sign; lstrip('-') would also accept "--5".
    unsigned = v[1:] if v.startswith('-') else v
    digits = unsigned.replace('.', '', 1)
    # isdecimal() (unlike isdigit()) rejects characters such as superscript
    # digits that int()/float() cannot parse.
    return digits.isdecimal()
142
143
def to_num(v: Any) -> Union[int, float]:
    """
    Convert a value to a number for constant folding.

    ``int``/``float`` values (including ``bool``) are returned unchanged.
    Anything else is converted via ``str`` and parsed as an ``int`` first,
    then as a ``float``.

    Raises:
        ValueError: if the text is not a numeric literal, or if it parses to a
            non-finite float. ``float()`` accepts spellings such as "inf",
            "nan" and "Infinity", which are also valid variable names; treating
            them as constants would fold a variable as if it were a literal.
    """
    import math  # local import: keeps this helper self-contained

    if isinstance(v, (int, float)):
        return v
    s = str(v)
    try:
        return int(s)
    except ValueError:
        pass
    value = float(s)  # raises ValueError for non-numeric text
    if not math.isfinite(value):
        raise ValueError(f"not a finite numeric literal: {s!r}")
    return value
153
154
def is_number(v: Any) -> bool:
    """Return True when ``to_num(v)`` succeeds, False when it raises ValueError/TypeError."""
    try:
        to_num(v)
    except (ValueError, TypeError):
        return False
    else:
        return True
161
162
163# ---------------------------------------------------------------------------
164# Optimization Pass Base
165# ---------------------------------------------------------------------------
166
class OptPass:
    """Common interface for all optimization passes; subclasses override ``run``."""

    # Human-readable pass name shown in the optimization log.
    name: str = "unnamed"

    def run(self, instrs: list[TACInstr]) -> tuple[list[TACInstr], int]:
        """
        Apply the pass to ``instrs``.

        Returns a pair ``(new_instrs, num_changes)``: the rewritten instruction
        list and the number of changes the pass made.
        """
        raise NotImplementedError()
177
178
179# ---------------------------------------------------------------------------
180# Pass 1: Constant Folding
181# ---------------------------------------------------------------------------
182
183def fold_binop(op: str, l: Any, r: Any) -> Optional[Any]:
184 """
185 Try to fold a binary operation on two constant operands.
186 Returns the folded value or None if not foldable.
187 """
188 if not (is_number(l) and is_number(r)):
189 return None
190 lv, rv = to_num(l), to_num(r)
191 try:
192 match op:
193 case '+': return int(lv + rv) if isinstance(lv, int) and isinstance(rv, int) else lv + rv
194 case '-': return int(lv - rv) if isinstance(lv, int) and isinstance(rv, int) else lv - rv
195 case '*': return int(lv * rv) if isinstance(lv, int) and isinstance(rv, int) else lv * rv
196 case '/':
197 if rv == 0: return None
198 return lv // rv if isinstance(lv, int) and isinstance(rv, int) else lv / rv
199 case '%': return int(lv) % int(rv) if rv != 0 else None
200 case '<': return int(lv < rv)
201 case '>': return int(lv > rv)
202 case '<=': return int(lv <= rv)
203 case '>=': return int(lv >= rv)
204 case '==': return int(lv == rv)
205 case '!=': return int(lv != rv)
206 except Exception:
207 pass
208 return None
209
210
class ConstantFolding(OptPass):
    """Replace operations whose operands are all constants by their computed value."""

    name = "Constant Folding"

    def run(self, instrs: list[TACInstr]) -> tuple[list[TACInstr], int]:
        """Return ``(new_instrs, num_folded)``; unfoldable instructions pass through unchanged."""
        result = []
        folded_count = 0
        for instr in instrs:
            replacement = self._try_fold(instr)
            if replacement is None:
                result.append(instr)
            else:
                result.append(replacement)
                folded_count += 1
        return result, folded_count

    @staticmethod
    def _try_fold(instr: TACInstr) -> Optional[Assign]:
        """Return the ``Assign`` that replaces ``instr``, or None if it cannot be folded."""
        if isinstance(instr, BinOp):
            value = fold_binop(instr.op, instr.left, instr.right)
            return None if value is None else Assign(instr.dest, value)
        if isinstance(instr, UnaryOp) and is_number(instr.src):
            if instr.op == '-':
                return Assign(instr.dest, -to_num(instr.src))
            if instr.op == '!':
                # Logical not folds to the integers 0 or 1.
                return Assign(instr.dest, int(not to_num(instr.src)))
        return None
235
236
237# ---------------------------------------------------------------------------
238# Pass 2: Constant Propagation
239# ---------------------------------------------------------------------------
240
class ConstantPropagation(OptPass):
    """Replace reads of variables whose value is a known constant by that constant."""

    name = "Constant Propagation"

    def run(self, instrs: list[TACInstr]) -> tuple[list[TACInstr], int]:
        """
        Forward analysis: maintain a map {var -> constant_value}.

        Variable references are replaced with their known constant values.
        A variable's entry is dropped when it is assigned a non-constant or
        defined by a BinOp/UnaryOp. The whole map is cleared at every Label:
        a label can be reached from jumps elsewhere in the program, so facts
        gathered on the fall-through path are not guaranteed to hold there.

        Returns ``(new_instrs, num_changes)`` where ``num_changes`` counts the
        instructions in which at least one operand was replaced.
        """
        const_map: dict[str, Any] = {}
        new_instrs: list[TACInstr] = []
        changes = 0

        def subst(v: Any) -> Any:
            # Only strings can be variable names; literals pass through.
            if isinstance(v, str) and v in const_map:
                return const_map[v]
            return v

        for instr in instrs:
            if isinstance(instr, Label):
                # Block boundary: forget everything known so far.
                const_map.clear()
                new_instrs.append(instr)

            elif isinstance(instr, Assign):
                new_src = subst(instr.src)
                if new_src != instr.src:
                    changes += 1
                new_instrs.append(Assign(instr.dest, new_src))
                # Update const map
                if is_number(new_src):
                    const_map[instr.dest] = new_src
                else:
                    const_map.pop(instr.dest, None)

            elif isinstance(instr, BinOp):
                nl = subst(instr.left)
                nr = subst(instr.right)
                if nl != instr.left or nr != instr.right:
                    changes += 1
                new_instrs.append(BinOp(instr.dest, nl, instr.op, nr))
                # The result is not a known constant until it has been folded.
                const_map.pop(instr.dest, None)

            elif isinstance(instr, UnaryOp):
                ns = subst(instr.src)
                if ns != instr.src:
                    changes += 1
                new_instrs.append(UnaryOp(instr.dest, instr.op, ns))
                const_map.pop(instr.dest, None)

            elif isinstance(instr, CondJump):
                nc = subst(instr.condition)
                if nc != instr.condition:
                    changes += 1
                new_instrs.append(CondJump(nc, instr.true_target, instr.false_target))

            elif isinstance(instr, Return):
                nv = subst(instr.value) if instr.value is not None else None
                if nv != instr.value:
                    changes += 1
                new_instrs.append(Return(nv))

            else:
                # Jump (and anything unrecognized): no operands to rewrite.
                new_instrs.append(instr)

        return new_instrs, changes
303
304
305# ---------------------------------------------------------------------------
306# Pass 3: Algebraic Simplification
307# ---------------------------------------------------------------------------
308
309class AlgebraicSimplification(OptPass):
310 name = "Algebraic Simplification"
311
312 def run(self, instrs: list[TACInstr]) -> tuple[list[TACInstr], int]:
313 new_instrs = []
314 changes = 0
315 for instr in instrs:
316 if isinstance(instr, BinOp):
317 simplified = self._simplify(instr)
318 if simplified is not instr:
319 new_instrs.append(simplified)
320 changes += 1
321 continue
322 new_instrs.append(instr)
323 return new_instrs, changes
324
325 def _simplify(self, b: BinOp) -> TACInstr:
326 l, op, r = b.left, b.op, b.right
327 # x + 0 = x, 0 + x = x
328 if op == '+' and r == 0: return Assign(b.dest, l)
329 if op == '+' and l == 0: return Assign(b.dest, r)
330 # x - 0 = x
331 if op == '-' and r == 0: return Assign(b.dest, l)
332 # x * 1 = x, 1 * x = x
333 if op == '*' and r == 1: return Assign(b.dest, l)
334 if op == '*' and l == 1: return Assign(b.dest, r)
335 # x * 0 = 0, 0 * x = 0
336 if op == '*' and (r == 0 or l == 0): return Assign(b.dest, 0)
337 # x / 1 = x
338 if op == '/' and r == 1: return Assign(b.dest, l)
339 # x - x = 0 (only if same variable)
340 if op == '-' and l == r and isinstance(l, str): return Assign(b.dest, 0)
341 # x / x = 1 (only if same variable, ignore division by zero)
342 if op == '/' and l == r and isinstance(l, str): return Assign(b.dest, 1)
343 return b
344
345
346# ---------------------------------------------------------------------------
347# Pass 4: Common Subexpression Elimination (CSE)
348# ---------------------------------------------------------------------------
349
class CSE(OptPass):
    """Reuse the result of an identical expression computed earlier in the same block."""

    name = "Common Subexpression Elimination"

    def run(self, instrs: list[TACInstr]) -> tuple[list[TACInstr], int]:
        """
        For each BinOp/UnaryOp, check if the same expression was computed before.
        If so, replace it with a copy from the variable holding the earlier result.

        A cached expression is dropped as soon as it can no longer be trusted:
          * one of its operand variables is redefined, or
          * the variable holding its result is redefined.
        All cached information is discarded at every Label (block boundary).

        Expression keys are ``(left, op, right)`` for BinOp and ``(op, src)``
        for UnaryOp. Returns ``(new_instrs, num_replaced)``.
        """
        # expression key -> variable currently holding its value
        available: dict[tuple, str] = {}
        # operand variable -> keys of expressions that read it
        keys_reading: dict[str, set[tuple]] = {}
        # variable -> keys of expressions whose value it was recorded as holding
        keys_held_by: dict[str, set[tuple]] = {}

        new_instrs: list[TACInstr] = []
        changes = 0

        def kill(var: str) -> None:
            """Forget every cached expression that 'var' invalidates by being redefined."""
            for key in keys_reading.pop(var, ()):
                available.pop(key, None)
            for key in keys_held_by.pop(var, ()):
                # The key may since have been re-recorded under another holder;
                # only drop it if 'var' is still the one holding the value.
                if available.get(key) == var:
                    del available[key]

        def record(key: tuple, holder: str, operands: list) -> None:
            available[key] = holder
            keys_held_by.setdefault(holder, set()).add(key)
            for operand in operands:
                if isinstance(operand, str):
                    keys_reading.setdefault(operand, set()).add(key)

        def process(instr: TACInstr, key: tuple, operands: list) -> None:
            nonlocal changes
            dest = instr.dest
            # Look the holder up BEFORE any invalidation: killing 'dest' may
            # remove this very key from the table.
            holder = available.get(key)
            if holder is not None:
                new_instrs.append(Assign(dest, holder))
                changes += 1
                # If dest already holds the value, nothing is redefined.
                if holder != dest:
                    kill(dest)
                return
            new_instrs.append(instr)
            kill(dest)
            # 'a = a + b' changes 'a', so afterwards 'a' does NOT hold the
            # value of the expression 'a + b'; do not cache it.
            if dest not in operands:
                record(key, dest, operands)

        for instr in instrs:
            if isinstance(instr, BinOp):
                process(instr, (instr.left, instr.op, instr.right),
                        [instr.left, instr.right])

            elif isinstance(instr, UnaryOp):
                process(instr, (instr.op, instr.src), [instr.src])

            elif isinstance(instr, Assign):
                # Redefining dest invalidates expressions that read it and
                # expressions whose result it was holding.
                kill(instr.dest)
                new_instrs.append(instr)

            elif isinstance(instr, Label):
                # At a label (block boundary), clear all cached info
                available.clear()
                keys_reading.clear()
                keys_held_by.clear()
                new_instrs.append(instr)

            else:
                new_instrs.append(instr)

        return new_instrs, changes
421
422
423# ---------------------------------------------------------------------------
424# Pass 5: Dead Code Elimination
425# ---------------------------------------------------------------------------
426
class DeadCodeElimination(OptPass):
    """Remove assignments to temporaries whose value is never read afterwards."""

    name = "Dead Code Elimination"

    @staticmethod
    def _uses_of(instr: TACInstr) -> set[str]:
        """Return the names read by 'instr' (string operands only)."""
        if isinstance(instr, (Assign, UnaryOp)):
            operands = (instr.src,)
        elif isinstance(instr, BinOp):
            operands = (instr.left, instr.right)
        elif isinstance(instr, CondJump):
            operands = (instr.condition,)
        elif isinstance(instr, Return):
            operands = (instr.value,)
        else:
            operands = ()
        return {operand for operand in operands if isinstance(operand, str)}

    @staticmethod
    def _def_of(instr: TACInstr) -> Optional[str]:
        """Return the variable written by 'instr', or None if it writes nothing."""
        if isinstance(instr, (Assign, BinOp, UnaryOp)):
            return instr.dest
        return None

    @staticmethod
    def _is_temp(var_name: str) -> bool:
        """True for compiler temporaries named t0, t1, ... (user variables are kept)."""
        return var_name.startswith('t') and var_name[1:].isdigit()

    def run(self, instrs: list[TACInstr]) -> tuple[list[TACInstr], int]:
        """
        Remove instructions that assign to temporaries that are never
        subsequently used, using a single backward liveness scan.

        A variable is 'live' at a point if it may be used after that point.
        An assignment 'dest = ...' is dead if dest is not live after it.

        The scan is linear, so it cannot follow control flow. To stay safe,
        every variable read anywhere in the program is assumed live right
        after a Jump/CondJump, since the jump target may need it.

        Returns ``(new_instrs, num_removed)``.
        """
        # Conservative live-out set for any jump: every name read anywhere.
        all_uses: set[str] = set()
        for instr in instrs:
            all_uses |= self._uses_of(instr)

        live: set[str] = set()  # variables live AFTER the instruction being visited
        kept_reversed: list[TACInstr] = []
        changes = 0

        for instr in reversed(instrs):
            if isinstance(instr, (Jump, CondJump)):
                live |= all_uses

            dest = self._def_of(instr)
            if dest is not None and self._is_temp(dest) and dest not in live:
                # Dead: drop it. Its operands are not marked live on its
                # behalf, since the instruction no longer exists.
                changes += 1
                continue

            # live-before = (live-after - def) | uses. The def must be removed
            # BEFORE the uses are added, otherwise 't1 = t1 + 1' would wrongly
            # make t1 dead above it.
            if dest is not None:
                live.discard(dest)
            live |= self._uses_of(instr)
            kept_reversed.append(instr)

        kept_reversed.reverse()
        return kept_reversed, changes
484
485
486# ---------------------------------------------------------------------------
487# Optimization pipeline
488# ---------------------------------------------------------------------------
489
def run_optimizer(instrs: list[TACInstr], max_passes: int = 10) -> list[TACInstr]:
    """
    Apply the full pass pipeline repeatedly and return the resulting instructions.

    Each iteration runs, in order: constant propagation, constant folding,
    algebraic simplification, CSE and dead code elimination. Iteration stops
    when a full round makes no change, or after ``max_passes`` rounds.
    A log line is printed for every pass that changed something.
    """
    passes = (
        ConstantPropagation(),
        ConstantFolding(),
        AlgebraicSimplification(),
        CSE(),
        DeadCodeElimination(),
    )

    print("\nOptimization log:")
    for iteration in range(max_passes):
        change_counts = []
        for p in passes:
            instrs, n = p.run(instrs)
            change_counts.append(n)
            if not n:
                continue
            print(f" Pass {iteration+1} [{p.name}]: {n} change(s)")
        if sum(change_counts) == 0:
            print(f" Fixed point reached after {iteration+1} iteration(s).")
            break

    return instrs
515
516
517# ---------------------------------------------------------------------------
518# Demo
519# ---------------------------------------------------------------------------
520
def make_sample_tac() -> list[TACInstr]:
    """
    Build the first demo program: eight TAC instructions ending in ``return t3``.

    The program mixes several optimization opportunities:
      - ``3 + 4`` appears twice with constant operands
      - ``x * 1`` and ``t1 - 0`` match algebraic identities
      - ``t4``/``t5`` recompute what ``t0``/``t3`` already hold and are never read
      - ``t6`` is assigned but never read
    """
    program: list[TACInstr] = []

    # Value that is actually returned: t3 = (3 + 4) * ((x * 1) - 0)
    program.append(BinOp('t0', 3, '+', 4))        # constant operands
    program.append(BinOp('t1', 'x', '*', 1))      # multiply by one
    program.append(BinOp('t2', 't1', '-', 0))     # subtract zero
    program.append(BinOp('t3', 't0', '*', 't2'))

    # Repeats the t0 and t3 computations under new names.
    program.append(BinOp('t4', 3, '+', 4))
    program.append(BinOp('t5', 't4', '*', 't2'))

    # Result is never read anywhere in the program.
    program.append(BinOp('t6', 'a', '+', 'b'))

    program.append(Return('t3'))
    return program
542
543
def print_tac(label: str, instrs: list[TACInstr]) -> None:
    """Print a blank line, then ``label:``, then each instruction on its own line."""
    rendered = [f"\n{label}:"]
    rendered.extend(str(instr) for instr in instrs)
    print("\n".join(rendered))
548
549
def main():
    """Run both demos: print each program before and after optimization."""
    banner = "=" * 60

    # Demo 1: the mixed sample program.
    print(banner)
    print("TAC Optimizer Demo")
    print(banner)

    original = make_sample_tac()
    print_tac("Original TAC", original)

    # Optimize a shallow copy so 'original' keeps its length for the summary line.
    optimized = run_optimizer(list(original))
    print_tac("Optimized TAC", optimized)

    print(f"\nReduction: {len(original)} -> {len(optimized)} instructions")

    # Demo 2: constants flowing through a chain of temporaries into a branch.
    print("\n" + banner)
    print("Example 2: Constant propagation chain")
    tac2 = [
        Assign('a', 5),
        Assign('b', 3),
        BinOp('t0', 'a', '+', 'b'),   # 5 + 3
        BinOp('t1', 't0', '*', 2),    # (5 + 3) * 2
        BinOp('t2', 't1', '-', 1),    # (5 + 3) * 2 - 1
        CondJump('t2', 'L_true', 'L_false'),
        Label('L_true'),
        Return('t2'),
        Label('L_false'),
        Return(0),
    ]
    print_tac("Original", tac2)
    opt2 = run_optimizer(list(tac2))
    print_tac("Optimized", opt2)
581
582
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()