09_optimizer.py

Download
python 585 lines 19.3 KB
  1"""
  209_optimizer.py - Local Optimizations on Three-Address Code
  3
  4Demonstrates classic local (basic-block-level) optimizations:
  5
  6  1. Constant Folding
  7     Evaluate expressions with all-constant operands at compile time.
  8     e.g.,  t1 = 3 + 4  -->  t1 = 7
  9
 10  2. Constant Propagation
 11     Replace uses of a variable with its known constant value.
 12     e.g.,  x = 5; t1 = x + 2  -->  x = 5; t1 = 5 + 2 --> t1 = 7
 13
 14  3. Algebraic Simplification
 15     Apply algebraic identities to simplify expressions.
 16     e.g.,  t = x * 1  -->  t = x
 17            t = x + 0  -->  t = x
 18            t = x * 0  -->  t = 0
 19            t = x - x  -->  t = 0
 20
 21  4. Common Subexpression Elimination (CSE)
 22     Avoid recomputing the same expression twice within a block.
 23     e.g.,  t1 = a + b; t2 = a + b  -->  t1 = a + b; t2 = t1
 24
 25  5. Dead Code Elimination
 26     Remove assignments to temporaries that are never used.
 27     e.g.,  t1 = 5 (never read)  -->  (removed)
 28
 29Topics covered:
 30  - Value numbering for CSE
 31  - Dataflow information within a basic block
 32  - Iterative optimization (apply passes until no changes)
 33  - Counting how many optimizations each pass performed
 34"""
 35
 36from __future__ import annotations
 37from dataclasses import dataclass, field
 38from typing import Any, Optional, Union
 39
 40
 41# ---------------------------------------------------------------------------
 42# TAC instruction types (simplified from 08_three_address_code.py)
 43# ---------------------------------------------------------------------------
 44
 45@dataclass
 46class Assign:
 47    dest: str
 48    src: Any        # literal value or variable name
 49
 50    def __str__(self):
 51        return f"    {self.dest} = {self.src}"
 52
 53    def copy(self):
 54        return Assign(self.dest, self.src)
 55
 56
 57@dataclass
 58class BinOp:
 59    dest: str
 60    left: Any
 61    op: str
 62    right: Any
 63
 64    def __str__(self):
 65        return f"    {self.dest} = {self.left} {self.op} {self.right}"
 66
 67    def copy(self):
 68        return BinOp(self.dest, self.left, self.op, self.right)
 69
 70
 71@dataclass
 72class UnaryOp:
 73    dest: str
 74    op: str
 75    src: Any
 76
 77    def __str__(self):
 78        return f"    {self.dest} = {self.op}{self.src}"
 79
 80    def copy(self):
 81        return UnaryOp(self.dest, self.op, self.src)
 82
 83
 84@dataclass
 85class Label:
 86    name: str
 87
 88    def __str__(self):
 89        return f"{self.name}:"
 90
 91    def copy(self):
 92        return Label(self.name)
 93
 94
 95@dataclass
 96class Jump:
 97    target: str
 98
 99    def __str__(self):
100        return f"    goto {self.target}"
101
102    def copy(self):
103        return Jump(self.target)
104
105
@dataclass
class CondJump:
    """Conditional jump on ``condition`` with an optional else-target."""

    condition: Any
    true_target: str
    false_target: Optional[str] = None

    def __str__(self) -> str:
        s = f"    if {self.condition} goto {self.true_target}"
        # The else-part is printed only for a truthy false_target
        # (so neither None nor an empty string produces one).
        if not self.false_target:
            return s
        return s + f" else goto {self.false_target}"

    def copy(self) -> "CondJump":
        """Return a new CondJump with the same condition and targets."""
        return CondJump(
            condition=self.condition,
            true_target=self.true_target,
            false_target=self.false_target,
        )
120
121
@dataclass
class Return:
    """Return instruction, with or without a value."""

    value: Optional[Any] = None

    def __str__(self) -> str:
        # Only None means "no value"; falsy values such as 0 are still printed.
        if self.value is None:
            return "    return"
        return f"    return {self.value}"

    def copy(self) -> "Return":
        """Return a new Return carrying the same value."""
        return Return(value=self.value)
131
132
133TACInstr = Union[Assign, BinOp, UnaryOp, Label, Jump, CondJump, Return]
134
135
136# ---------------------------------------------------------------------------
137# Utility: is a value a compile-time constant?
138# ---------------------------------------------------------------------------
139
def is_const(v: Any) -> bool:
    """
    Report whether *v* is a compile-time constant.

    Python ints, floats and bools are constants. A string is a constant only
    when it spells a plain decimal literal: an optional single leading '-',
    ASCII digits, and at most one '.'.
    """
    if isinstance(v, (int, float, bool)):
        return True
    if not isinstance(v, str):
        return False
    # Strip at most ONE sign character; lstrip('-') would also accept "--5",
    # which neither int() nor float() can parse.
    unsigned = v[1:] if v.startswith('-') else v
    digits = unsigned.replace('.', '', 1)
    # isdigit() alone is true for characters such as superscript digits that
    # are not valid numeric literals, so require ASCII as well.
    return digits.isascii() and digits.isdigit()
142
143
def to_num(v: Any) -> Union[int, float]:
    """
    Convert a value to a number for constant folding.

    ints and floats (including bools, which are ints) are returned unchanged.
    Anything else is converted via its string form: first as an int, then as
    a float.

    Raises:
        ValueError: if the string form is not a numeric literal, or if it
            parses to a non-finite float.
    """
    if isinstance(v, (int, float)):
        return v

    import math  # local import: only needed on the string-parsing path

    s = str(v)
    try:
        return int(s)
    except ValueError:
        pass
    value = float(s)  # raises ValueError for non-numeric text
    # float() also accepts "nan", "inf" and "infinity". In TAC operands those
    # strings would be variable names, so they must not count as constants.
    if not math.isfinite(value):
        raise ValueError(f"not a finite numeric literal: {s!r}")
    return value
153
154
def is_number(v: Any) -> bool:
    """Report whether to_num() can convert *v* without raising ValueError or TypeError."""
    try:
        to_num(v)
    except (ValueError, TypeError):
        return False
    else:
        return True
161
162
163# ---------------------------------------------------------------------------
164# Optimization Pass Base
165# ---------------------------------------------------------------------------
166
class OptPass:
    """Common interface for optimization passes; subclasses override run()."""

    # Human-readable pass name; subclasses replace it.
    name: str = "unnamed"

    def run(self, instrs: list[TACInstr]) -> tuple[list[TACInstr], int]:
        """
        Apply the pass to *instrs*.

        Returns a pair (new_instrs, num_changes). This base implementation
        always raises NotImplementedError.
        """
        raise NotImplementedError()
177
178
179# ---------------------------------------------------------------------------
180# Pass 1: Constant Folding
181# ---------------------------------------------------------------------------
182
183def fold_binop(op: str, l: Any, r: Any) -> Optional[Any]:
184    """
185    Try to fold a binary operation on two constant operands.
186    Returns the folded value or None if not foldable.
187    """
188    if not (is_number(l) and is_number(r)):
189        return None
190    lv, rv = to_num(l), to_num(r)
191    try:
192        match op:
193            case '+':  return int(lv + rv) if isinstance(lv, int) and isinstance(rv, int) else lv + rv
194            case '-':  return int(lv - rv) if isinstance(lv, int) and isinstance(rv, int) else lv - rv
195            case '*':  return int(lv * rv) if isinstance(lv, int) and isinstance(rv, int) else lv * rv
196            case '/':
197                if rv == 0: return None
198                return lv // rv if isinstance(lv, int) and isinstance(rv, int) else lv / rv
199            case '%':  return int(lv) % int(rv) if rv != 0 else None
200            case '<':  return int(lv < rv)
201            case '>':  return int(lv > rv)
202            case '<=': return int(lv <= rv)
203            case '>=': return int(lv >= rv)
204            case '==': return int(lv == rv)
205            case '!=': return int(lv != rv)
206    except Exception:
207        pass
208    return None
209
210
class ConstantFolding(OptPass):
    """Replace operations whose operands are all numeric with a plain Assign."""

    name = "Constant Folding"

    def run(self, instrs: list[TACInstr]) -> tuple[list[TACInstr], int]:
        """Return (new_instrs, number_of_instructions_folded)."""
        out: list[TACInstr] = []
        folded_count = 0
        for ins in instrs:
            replacement = self._try_fold(ins)
            if replacement is None:
                out.append(ins)
            else:
                out.append(replacement)
                folded_count += 1
        return out, folded_count

    @staticmethod
    def _try_fold(ins: TACInstr) -> Optional[Assign]:
        """Return the folded Assign for *ins*, or None if it cannot be folded."""
        if isinstance(ins, BinOp):
            value = fold_binop(ins.op, ins.left, ins.right)
            return None if value is None else Assign(ins.dest, value)
        if isinstance(ins, UnaryOp) and is_number(ins.src):
            if ins.op == '-':
                return Assign(ins.dest, -to_num(ins.src))
            if ins.op == '!':
                # Logical not folds to 1 or 0.
                return Assign(ins.dest, int(not to_num(ins.src)))
        return None
235
236
237# ---------------------------------------------------------------------------
238# Pass 2: Constant Propagation
239# ---------------------------------------------------------------------------
240
class ConstantPropagation(OptPass):
    """Replace variable operands with their known constant values."""

    name = "Constant Propagation"

    def run(self, instrs: list[TACInstr]) -> tuple[list[TACInstr], int]:
        """
        Forward analysis: maintain a map {var -> constant_value}.

        Operands that name a variable in the map are replaced by its value.
        A variable leaves the map when it is reassigned to a non-constant.
        The whole map is cleared at every Label, because a label can be
        reached from other jumps where the fall-through facts do not hold.

        Returns (new_instrs, num_changes); num_changes counts the
        instructions in which at least one operand was replaced.
        """
        const_map: dict[str, Any] = {}
        new_instrs = []
        changes = 0

        def subst(v: Any) -> Any:
            # Only strings can be variable names; literals pass through.
            if isinstance(v, str) and v in const_map:
                return const_map[v]
            return v

        for instr in instrs:
            if isinstance(instr, Label):
                # Block boundary: forget everything learned so far.
                const_map.clear()
                new_instrs.append(instr)

            elif isinstance(instr, Assign):
                new_src = subst(instr.src)
                if new_src != instr.src:
                    changes += 1
                new_instrs.append(Assign(instr.dest, new_src))
                # dest now holds a known constant, or an unknown value.
                if is_number(new_src):
                    const_map[instr.dest] = new_src
                else:
                    const_map.pop(instr.dest, None)

            elif isinstance(instr, BinOp):
                nl = subst(instr.left)
                nr = subst(instr.right)
                if nl != instr.left or nr != instr.right:
                    changes += 1
                new_instrs.append(BinOp(instr.dest, nl, instr.op, nr))
                # The result is not evaluated here, so dest becomes unknown.
                const_map.pop(instr.dest, None)

            elif isinstance(instr, UnaryOp):
                ns = subst(instr.src)
                if ns != instr.src:
                    changes += 1
                new_instrs.append(UnaryOp(instr.dest, instr.op, ns))
                const_map.pop(instr.dest, None)

            elif isinstance(instr, CondJump):
                nc = subst(instr.condition)
                if nc != instr.condition:
                    changes += 1
                new_instrs.append(CondJump(nc, instr.true_target, instr.false_target))

            elif isinstance(instr, Return):
                nv = subst(instr.value) if instr.value is not None else None
                if nv != instr.value:
                    changes += 1
                new_instrs.append(Return(nv))

            else:
                # Jump (and anything unrecognised) defines and uses nothing.
                new_instrs.append(instr)

        return new_instrs, changes
303
304
305# ---------------------------------------------------------------------------
306# Pass 3: Algebraic Simplification
307# ---------------------------------------------------------------------------
308
class AlgebraicSimplification(OptPass):
    """Rewrite BinOps that match a simple algebraic identity into an Assign."""

    name = "Algebraic Simplification"

    def run(self, instrs: list[TACInstr]) -> tuple[list[TACInstr], int]:
        """Return (new_instrs, number_of_BinOps_simplified)."""
        out: list[TACInstr] = []
        simplified_count = 0
        for ins in instrs:
            result = self._simplify(ins) if isinstance(ins, BinOp) else ins
            # _simplify hands back the very same object when no rule applies.
            if result is not ins:
                simplified_count += 1
            out.append(result)
        return out, simplified_count

    def _simplify(self, b: BinOp) -> TACInstr:
        """Return an equivalent Assign for *b*, or *b* itself if no identity matches."""
        lhs, op, rhs = b.left, b.op, b.right
        # True when both operands are the same variable name.
        same_var = isinstance(lhs, str) and lhs == rhs

        if op == '+':
            # x + 0 = x,  0 + x = x
            if rhs == 0:
                return Assign(b.dest, lhs)
            if lhs == 0:
                return Assign(b.dest, rhs)
        elif op == '-':
            # x - 0 = x,  x - x = 0
            if rhs == 0:
                return Assign(b.dest, lhs)
            if same_var:
                return Assign(b.dest, 0)
        elif op == '*':
            # x * 1 = x,  1 * x = x,  then x * 0 = 0 * x = 0
            if rhs == 1:
                return Assign(b.dest, lhs)
            if lhs == 1:
                return Assign(b.dest, rhs)
            if rhs == 0 or lhs == 0:
                return Assign(b.dest, 0)
        elif op == '/':
            # x / 1 = x,  x / x = 1 (the x == 0 case is deliberately ignored)
            if rhs == 1:
                return Assign(b.dest, lhs)
            if same_var:
                return Assign(b.dest, 1)
        return b
344
345
346# ---------------------------------------------------------------------------
347# Pass 4: Common Subexpression Elimination (CSE)
348# ---------------------------------------------------------------------------
349
class CSE(OptPass):
    """Reuse the result of an identical earlier expression within a block."""

    name = "Common Subexpression Elimination"

    def run(self, instrs: list[TACInstr]) -> tuple[list[TACInstr], int]:
        """
        For each BinOp/UnaryOp, check if the same expression was computed before.
        If so, replace it with a copy of the earlier result.

        A cached expression is dropped when one of its operand variables is
        redefined or when the variable holding its result is redefined. All
        cached information is cleared at every Label.

        Returns (new_instrs, num_changes).
        """
        # Maps expression key -> variable currently holding its value.
        expr_map: dict[tuple, str] = {}
        # Maps variable -> keys that depend on it, as operand or as holder.
        # Entries may be stale (key already dropped); that only costs a
        # missed reuse, never a wrong one.
        var_to_exprs: dict[str, set[tuple]] = {}

        new_instrs = []
        changes = 0

        def tag(v: Any) -> tuple:
            # 2, 2.0 and True compare and hash equal; tagging with the type
            # keeps e.g. "x / 2" and "x / 2.0" as different expressions.
            return (type(v), v)

        def invalidate(var: str) -> None:
            """Drop every cached expression that reads 'var' or is held in it."""
            for key in var_to_exprs.pop(var, ()):
                expr_map.pop(key, None)

        def record(key: tuple, dest: str, operands: list) -> None:
            # "x = x + 1" overwrites its own operand, so the stored value no
            # longer equals "x + 1" afterwards: never cache it.
            if dest in operands:
                return
            expr_map[key] = dest
            for name in (dest, *operands):
                if isinstance(name, str):
                    var_to_exprs.setdefault(name, set()).add(key)

        def eliminate(instr: TACInstr, key: tuple, operands: list) -> None:
            nonlocal changes
            holder = expr_map.get(key)
            if holder is None:
                new_instrs.append(instr)
                invalidate(instr.dest)
                record(key, instr.dest, operands)
                return
            # Replace with a copy from the earlier result.
            new_instrs.append(Assign(instr.dest, holder))
            changes += 1
            # A self-copy leaves dest unchanged, so nothing goes stale;
            # otherwise dest receives a new value.
            if holder != instr.dest:
                invalidate(instr.dest)

        for instr in instrs:
            if isinstance(instr, BinOp):
                key = (tag(instr.left), instr.op, tag(instr.right))
                eliminate(instr, key, [instr.left, instr.right])

            elif isinstance(instr, UnaryOp):
                key = (instr.op, tag(instr.src))
                eliminate(instr, key, [instr.src])

            elif isinstance(instr, Assign):
                # dest is redefined: drop expressions that use or live in it.
                invalidate(instr.dest)
                new_instrs.append(instr)

            elif isinstance(instr, Label):
                # At a label (block boundary), clear all cached info.
                expr_map.clear()
                var_to_exprs.clear()
                new_instrs.append(instr)

            else:
                new_instrs.append(instr)

        return new_instrs, changes
421
422
423# ---------------------------------------------------------------------------
424# Pass 5: Dead Code Elimination
425# ---------------------------------------------------------------------------
426
class DeadCodeElimination(OptPass):
    """Remove assignments to temporaries whose value is never read afterwards."""

    name = "Dead Code Elimination"

    def run(self, instrs: list[TACInstr]) -> tuple[list[TACInstr], int]:
        """
        Remove instructions that assign to temporaries (t0, t1, ...) that are
        never subsequently used. Uses a single backward liveness scan.

        A variable is 'live' at a point if it may be used after that point.
        An assignment 'dest = ...' is dead if dest is not live after it.
        Liveness before an instruction is (live_after - def) | uses; the def
        is removed first so "t1 = t1 + 1" keeps t1 live above it.

        The scan does not follow jump targets. To stay safe with backward
        jumps, every variable used anywhere in *instrs* is treated as live
        after a Jump or CondJump.

        Returns (new_instrs, number_of_instructions_removed).
        """
        def uses_of(instr: TACInstr) -> set[str]:
            # Any string operand is treated as a variable name.
            u: set[str] = set()
            if isinstance(instr, Assign):
                if isinstance(instr.src, str): u.add(instr.src)
            elif isinstance(instr, BinOp):
                if isinstance(instr.left, str):  u.add(instr.left)
                if isinstance(instr.right, str): u.add(instr.right)
            elif isinstance(instr, UnaryOp):
                if isinstance(instr.src, str): u.add(instr.src)
            elif isinstance(instr, CondJump):
                if isinstance(instr.condition, str): u.add(instr.condition)
            elif isinstance(instr, Return):
                if isinstance(instr.value, str): u.add(instr.value)
            return u

        def def_of(instr: TACInstr) -> Optional[str]:
            if isinstance(instr, (Assign, BinOp, UnaryOp)):
                return instr.dest
            return None

        def is_temp(name: str) -> bool:
            # Only eliminate temporaries (t0, t1, ...), not user variables.
            return name.startswith('t') and name[1:].isdigit()

        # Conservative live set used at jumps.
        all_uses: set[str] = set()
        for instr in instrs:
            all_uses |= uses_of(instr)

        # 'live' is the live set AFTER the instruction about to be visited.
        live: set[str] = set()
        kept_reversed = []
        changes = 0
        for instr in reversed(instrs):
            if isinstance(instr, (Jump, CondJump)):
                live |= all_uses
            d = def_of(instr)
            if d is not None and is_temp(d) and d not in live:
                # Dead: drop it and do not count its operands as uses, so a
                # chain of dead temporaries is removed in one scan.
                changes += 1
                continue
            kept_reversed.append(instr)
            if d is not None:
                live.discard(d)
            live |= uses_of(instr)

        kept_reversed.reverse()
        return kept_reversed, changes
484
485
486# ---------------------------------------------------------------------------
487# Optimization pipeline
488# ---------------------------------------------------------------------------
489
def run_optimizer(instrs: list[TACInstr], max_passes: int = 10) -> list[TACInstr]:
    """
    Apply the five passes in a fixed order, repeating the whole sequence
    until one full round reports no changes or *max_passes* rounds have run.

    Prints a log line for every pass that changed something, plus a final
    line when a fixed point is reached. Returns the resulting instructions.
    """
    passes = (
        ConstantPropagation(),
        ConstantFolding(),
        AlgebraicSimplification(),
        CSE(),
        DeadCodeElimination(),
    )

    print("\nOptimization log:")
    for iteration in range(max_passes):
        total_changes = 0
        for p in passes:
            instrs, n = p.run(instrs)
            total_changes += n
            if n != 0:
                print(f"  Pass {iteration+1} [{p.name}]: {n} change(s)")
        if not total_changes:
            print(f"  Fixed point reached after {iteration+1} iteration(s).")
            break

    return instrs
515
516
517# ---------------------------------------------------------------------------
518# Demo
519# ---------------------------------------------------------------------------
520
def make_sample_tac() -> list[TACInstr]:
    """
    Build the first demo program: eight instructions ending in ``return t3``.

    It is meant to offer something to several passes:
      - 3 + 4 appears twice with constant operands
      - x * 1 and t1 - 0 match algebraic identities
      - t4, t5 and t6 are never read by the final return
    """
    program: list[TACInstr] = [
        BinOp('t0', 3, '+', 4),        # constant operands
        BinOp('t1', 'x', '*', 1),      # multiply by one
        BinOp('t2', 't1', '-', 0),     # subtract zero
        BinOp('t3', 't0', '*', 't2'),  # the value that is returned
        BinOp('t4', 3, '+', 4),        # same expression as t0
        BinOp('t5', 't4', '*', 't2'),  # result is never read
        BinOp('t6', 'a', '+', 'b'),    # result is never read
        Return('t3'),
    ]
    return program
542
543
def print_tac(label: str, instrs: list[TACInstr]) -> None:
    """Print a blank line, then ``label:``, then each instruction on its own line."""
    print()
    print(f"{label}:")
    for instruction in instrs:
        print(instruction)
548
549
def main():
    """Run both demo programs through the optimizer and print before/after listings."""
    banner = "=" * 60

    print(banner)
    print("TAC Optimizer Demo")
    print(banner)

    original = make_sample_tac()
    print_tac("Original TAC", original)

    # The optimizer gets its own list; 'original' is kept for the size report.
    optimized = run_optimizer(list(original))
    print_tac("Optimized TAC", optimized)

    print(f"\nReduction: {len(original)} -> {len(optimized)} instructions")

    # Second example: constants assigned to a and b feed a chain of temporaries.
    print("\n" + banner)
    print("Example 2: Constant propagation chain")
    tac2 = [
        Assign('a', 5),
        Assign('b', 3),
        BinOp('t0', 'a', '+', 'b'),
        BinOp('t1', 't0', '*', 2),
        BinOp('t2', 't1', '-', 1),
        CondJump('t2', 'L_true', 'L_false'),
        Label('L_true'),
        Return('t2'),
        Label('L_false'),
        Return(0),
    ]
    print_tac("Original", tac2)
    opt2 = run_optimizer(list(tac2))
    print_tac("Optimized", opt2)
581
582
583if __name__ == "__main__":
584    main()