05_quantization.py

Download
python 423 lines 13.5 KB
  1"""
  2Foundation Models - Quantization Techniques
  3
  4Demonstrates INT8/INT4 quantization for model compression.
  5Implements symmetric and asymmetric quantization schemes.
  6Shows quantization error analysis and calibration concepts.
  7
  8Requires: PyTorch, NumPy
  9"""
 10
 11import torch
 12import numpy as np
 13
 14
 15def symmetric_quantize(tensor, num_bits=8):
 16    """
 17    Symmetric quantization: maps [-α, α] to [-2^(b-1), 2^(b-1)-1]
 18
 19    Args:
 20        tensor: Input tensor
 21        num_bits: Number of bits (8 for INT8, 4 for INT4)
 22
 23    Returns:
 24        Tuple of (quantized tensor, scale factor)
 25    """
 26    # Determine quantization range
 27    q_min = -(2 ** (num_bits - 1))
 28    q_max = 2 ** (num_bits - 1) - 1
 29
 30    # Find scale: α / 2^(b-1)
 31    alpha = tensor.abs().max()
 32    scale = alpha / q_max
 33
 34    # Quantize: round(x / scale)
 35    quantized = torch.round(tensor / scale).clamp(q_min, q_max)
 36
 37    return quantized, scale
 38
 39
 40def symmetric_dequantize(quantized, scale):
 41    """
 42    Dequantize symmetric quantized tensor.
 43
 44    Args:
 45        quantized: Quantized tensor
 46        scale: Scale factor from quantization
 47
 48    Returns:
 49        Dequantized tensor
 50    """
 51    return quantized * scale
 52
 53
 54def asymmetric_quantize(tensor, num_bits=8):
 55    """
 56    Asymmetric quantization: maps [min, max] to [0, 2^b - 1]
 57
 58    Args:
 59        tensor: Input tensor
 60        num_bits: Number of bits
 61
 62    Returns:
 63        Tuple of (quantized tensor, scale, zero_point)
 64    """
 65    # Quantization range
 66    q_min = 0
 67    q_max = 2 ** num_bits - 1
 68
 69    # Calibration: find min and max
 70    r_min = tensor.min()
 71    r_max = tensor.max()
 72
 73    # Compute scale and zero point
 74    scale = (r_max - r_min) / (q_max - q_min)
 75    zero_point = q_min - torch.round(r_min / scale)
 76    zero_point = zero_point.clamp(q_min, q_max)
 77
 78    # Quantize: round(x / scale + zero_point)
 79    quantized = torch.round(tensor / scale + zero_point).clamp(q_min, q_max)
 80
 81    return quantized, scale, zero_point
 82
 83
 84def asymmetric_dequantize(quantized, scale, zero_point):
 85    """
 86    Dequantize asymmetric quantized tensor.
 87
 88    Args:
 89        quantized: Quantized tensor
 90        scale: Scale factor
 91        zero_point: Zero point offset
 92
 93    Returns:
 94        Dequantized tensor
 95    """
 96    return (quantized - zero_point) * scale
 97
 98
 99def compute_quantization_error(original, dequantized):
100    """
101    Compute quantization error metrics.
102
103    Args:
104        original: Original tensor
105        dequantized: Dequantized tensor
106
107    Returns:
108        Dictionary of error metrics
109    """
110    error = (original - dequantized).abs()
111
112    metrics = {
113        'max_error': error.max().item(),
114        'mean_error': error.mean().item(),
115        'mse': ((original - dequantized) ** 2).mean().item(),
116        'snr_db': 10 * torch.log10((original ** 2).mean() / ((error ** 2).mean() + 1e-10)).item(),
117    }
118
119    return metrics
120
121
def per_channel_quantize(tensor, num_bits=8, dim=0):
    """
    Per-channel (per-output) symmetric quantization for better accuracy.

    Each slice along ``dim`` gets its own scale α_c / (2^(b-1) - 1), where
    α_c is that slice's largest absolute value. Small-magnitude channels are
    therefore not forced onto the coarse grid of the largest channel.

    Args:
        tensor: Input tensor (e.g., weight matrix)
        num_bits: Number of bits; must be >= 2
        dim: Channel dimension

    Returns:
        Tuple of (quantized, scales). ``quantized`` has the same shape as
        ``tensor``; ``scales`` is a 1-D tensor with one entry per channel,
        on the same device as ``tensor``.

    Raises:
        ValueError: If ``num_bits`` < 2.
    """
    if num_bits < 2:
        raise ValueError(f"num_bits must be >= 2, got {num_bits}")

    q_min = -(2 ** (num_bits - 1))
    q_max = 2 ** (num_bits - 1) - 1

    # Move channel dim to front so that channel c is moved[c]
    moved = tensor.transpose(0, dim) if dim != 0 else tensor
    num_channels = moved.shape[0]

    # Per-channel max |x| in a single vectorized reduction (no Python loop)
    alpha = moved.reshape(num_channels, -1).abs().amax(dim=1)
    scales = alpha / q_max
    # An all-zero channel has α = 0; a zero scale would turn it into NaN.
    scales = torch.where(scales > 0, scales, torch.ones_like(scales))

    # Broadcast one scale per channel across the remaining dimensions
    broadcast_shape = (num_channels,) + (1,) * (moved.dim() - 1)
    quantized = torch.round(moved / scales.reshape(broadcast_shape)).clamp(q_min, q_max)

    # Transpose back (transpose(0, dim) is its own inverse)
    quantized = quantized.transpose(0, dim) if dim != 0 else quantized

    return quantized, scales
160
161
def per_channel_dequantize(quantized, scales, dim=0):
    """
    Dequantize a per-channel quantized tensor.

    Slice ``c`` along ``dim`` is multiplied by ``scales[c]``; the result has
    the same shape as ``quantized``.
    """
    channels_first = quantized if dim == 0 else quantized.transpose(0, dim)
    restored = torch.stack(
        [channels_first[c] * scales[c] for c in range(channels_first.shape[0])],
        dim=0,
    )
    if dim == 0:
        return restored
    return restored.transpose(0, dim)
176
177
178# ============================================================
179# Demonstrations
180# ============================================================
181
def demo_symmetric_quantization():
    """Quantize one random tensor at INT8 and INT4 and print error statistics."""
    rule = "=" * 60
    print(rule)
    print("DEMO 1: Symmetric Quantization")
    print(rule)

    # Fixed seed so the printed numbers are reproducible between runs
    torch.manual_seed(42)
    values = torch.randn(1000) * 10

    print("\nOriginal tensor:")
    print(f"  Shape: {values.shape}")
    print(f"  Range: [{values.min():.2f}, {values.max():.2f}]")
    print(f"  Mean: {values.mean():.2f}, Std: {values.std():.2f}")

    for bits in (8, 4):
        codes, step = symmetric_quantize(values, num_bits=bits)
        stats = compute_quantization_error(values, symmetric_dequantize(codes, step))

        print(f"\nINT{bits} Quantization:")
        print(f"  Scale: {step:.6f}")
        print(f"  Quantized range: [{codes.min():.0f}, {codes.max():.0f}]")
        print(f"  Max error: {stats['max_error']:.4f}")
        print(f"  Mean error: {stats['mean_error']:.4f}")
        print(f"  SNR: {stats['snr_db']:.2f} dB")
210
211
def demo_asymmetric_quantization():
    """Compare symmetric vs asymmetric INT8 on ReLU-style (non-negative) activations."""
    banner = "=" * 60
    print("\n" + banner)
    print("DEMO 2: Asymmetric Quantization")
    print(banner)

    # ReLU zeroes out negatives, so every value is >= 0 and the range is
    # one-sided.
    torch.manual_seed(42)
    activations = torch.relu(torch.randn(1000) * 5 + 2)

    print("\nOriginal tensor (ReLU activations):")
    print(f"  Range: [{activations.min():.2f}, {activations.max():.2f}]")
    print(f"  Mean: {activations.mean():.2f}, Std: {activations.std():.2f}")

    def report(reconstruction):
        # Error summary printed after each method's header lines
        stats = compute_quantization_error(activations, reconstruction)
        print(f"  Max error: {stats['max_error']:.4f}")
        print(f"  Mean error: {stats['mean_error']:.4f}")
        print(f"  SNR: {stats['snr_db']:.2f} dB")

    # Symmetric: a single scale, grid centred on zero
    sym_codes, sym_scale = symmetric_quantize(activations, num_bits=8)
    sym_recon = symmetric_dequantize(sym_codes, sym_scale)
    print("\nSymmetric INT8:")
    print(f"  Scale: {sym_scale:.6f}")
    report(sym_recon)

    # Asymmetric: a scale plus an integer zero point
    asym_codes, asym_scale, asym_zero = asymmetric_quantize(activations, num_bits=8)
    asym_recon = asymmetric_dequantize(asym_codes, asym_scale, asym_zero)
    print("\nAsymmetric INT8:")
    print(f"  Scale: {asym_scale:.6f}, Zero point: {asym_zero:.0f}")
    report(asym_recon)
243
244
def demo_per_channel_quantization():
    """Contrast one per-tensor scale with per-row scales on a weight matrix
    whose rows have very different magnitudes."""
    line = "=" * 60
    print("\n" + line)
    print("DEMO 3: Per-Channel Quantization")
    print(line)

    # Rows 0-63 are scaled up 10x and rows 64-127 scaled down 10x, so a
    # single shared scale is set by the large rows.
    torch.manual_seed(42)
    weight = torch.randn(256, 512)
    weight[:64].mul_(10)
    weight[64:128].mul_(0.1)

    print("\nWeight matrix:")
    print(f"  Shape: {weight.shape}")
    for label, region in (("Overall", weight), ("Channel 0", weight[0]), ("Channel 100", weight[100])):
        print(f"  {label} range: [{region.min():.2f}, {region.max():.2f}]")

    # One scale for the whole matrix
    codes_t, scale_t = symmetric_quantize(weight, num_bits=8)
    stats_t = compute_quantization_error(weight, symmetric_dequantize(codes_t, scale_t))

    print("\nPer-tensor quantization:")
    print(f"  Single scale: {scale_t:.6f}")
    print(f"  Max error: {stats_t['max_error']:.4f}")
    print(f"  SNR: {stats_t['snr_db']:.2f} dB")

    # One scale per row (dim 0)
    codes_c, scales_c = per_channel_quantize(weight, num_bits=8, dim=0)
    stats_c = compute_quantization_error(
        weight, per_channel_dequantize(codes_c, scales_c, dim=0)
    )

    print("\nPer-channel quantization:")
    print(f"  Scale range: [{scales_c.min():.6f}, {scales_c.max():.6f}]")
    print(f"  Max error: {stats_c['max_error']:.4f}")
    print(f"  SNR: {stats_c['snr_db']:.2f} dB")
    gain = stats_c['snr_db'] - stats_t['snr_db']
    print(f"  Improvement: {gain:.2f} dB")
285
286
def demo_bit_depth_comparison():
    """Compare SNR and storage size for 2/4/8/16-bit symmetric quantization."""
    print("\n" + "=" * 60)
    print("DEMO 4: Bit Depth Comparison")
    print("=" * 60)

    torch.manual_seed(42)
    tensor = torch.randn(10000) * 5

    fp32_bytes = tensor.numel() * 4
    scale_bytes = 4  # one FP32 scale for the whole tensor

    print(f"\nOriginal tensor: {tensor.numel()} values")
    print(f"  FP32 size: {fp32_bytes / 1024:.2f} KB")

    print("\n" + "-" * 60)
    print(f"{'Bits':<8} {'Size (KB)':<12} {'SNR (dB)':<12} {'Compression':<12}")
    print("-" * 60)

    for num_bits in [2, 4, 8, 16]:
        quantized, scale = symmetric_quantize(tensor, num_bits=num_bits)
        dequantized = symmetric_dequantize(quantized, scale)

        metrics = compute_quantization_error(tensor, dequantized)

        # Approximate size: values bit-packed at num_bits each, plus the scale
        quant_bytes = tensor.numel() * num_bits / 8 + scale_bytes
        size_kb = quant_bytes / 1024
        compression = fp32_bytes / quant_bytes

        # Attach the "x" suffix before padding so it stays next to the number;
        # "{compression:<12.2f}x" placed it after the column padding.
        ratio = f"{compression:.2f}x"
        print(f"{num_bits:<8} {size_kb:<12.2f} {metrics['snr_db']:<12.2f} {ratio:<12}")
314
315
def demo_calibration():
    """
    Demonstrate calibration for quantization.

    Compares two ways of choosing the symmetric INT8 clipping range on data
    containing a few large outliers:

    * min-max: the range is the largest |x|, so nothing is clipped;
    * percentile: |x| is clipped at its 99th percentile before quantizing.

    Both reconstructions are scored against the ORIGINAL tensor, so the
    error from clipping the outliers counts against the percentile strategy.
    SNR over the in-range (unclipped) values is also reported for both
    strategies; that is where the finer percentile grid helps.
    """
    print("\n" + "=" * 60)
    print("DEMO 5: Calibration Strategies")
    print("=" * 60)

    # Generate data with outliers
    torch.manual_seed(42)
    tensor = torch.randn(1000) * 2
    tensor[torch.randint(0, 1000, (10,))] = torch.randn(10) * 50  # Outliers

    print("\nData statistics:")
    print(f"  Min: {tensor.min():.2f}, Max: {tensor.max():.2f}")
    print(f"  Mean: {tensor.mean():.2f}, Std: {tensor.std():.2f}")
    print(f"  99th percentile: {torch.quantile(tensor.abs(), 0.99):.2f}")

    # Percentile clipping threshold, and the mask of values it leaves intact
    percentile = 99
    clip_value = torch.quantile(tensor.abs(), percentile / 100)
    in_range = tensor.abs() <= clip_value

    # Strategy 1: Min-max (uses full range)
    quant1, scale1 = symmetric_quantize(tensor, num_bits=8)
    dequant1 = symmetric_dequantize(quant1, scale1)
    metrics1 = compute_quantization_error(tensor, dequant1)
    metrics1_in = compute_quantization_error(tensor[in_range], dequant1[in_range])

    print("\nMin-max calibration:")
    print(f"  Clipping range: ±{tensor.abs().max():.2f}")
    print(f"  SNR: {metrics1['snr_db']:.2f} dB")
    print(f"  SNR (in-range values): {metrics1_in['snr_db']:.2f} dB")

    # Strategy 2: Percentile-based (clip outliers)
    tensor_clipped = tensor.clamp(-clip_value, clip_value)
    quant2, scale2 = symmetric_quantize(tensor_clipped, num_bits=8)
    dequant2 = symmetric_dequantize(quant2, scale2)
    # Score against the original data. Comparing with tensor_clipped would
    # hide the clipping error and overstate this strategy's accuracy.
    metrics2 = compute_quantization_error(tensor, dequant2)
    metrics2_in = compute_quantization_error(tensor[in_range], dequant2[in_range])

    print(f"\n{percentile}th percentile calibration:")
    print(f"  Clipping range: ±{clip_value:.2f}")
    print(f"  SNR: {metrics2['snr_db']:.2f} dB")
    print(f"  SNR (in-range values): {metrics2_in['snr_db']:.2f} dB")
    print(f"  Values clipped: {(~in_range).sum().item()}")
354
355
def demo_quantized_matmul():
    """
    Demonstrate quantized matrix multiplication.

    A and B are each quantized to INT8 with one symmetric scale. The product
    is simulated by dequantizing both operands and multiplying in FP32. An
    INT8 kernel would compute the same quantity (up to floating-point
    rounding) as (A_q @ B_q) * (A_scale * B_scale), accumulating in integers.
    """
    print("\n" + "=" * 60)
    print("DEMO 6: Quantized Matrix Multiplication")
    print("=" * 60)

    # Create matrices
    torch.manual_seed(42)
    A = torch.randn(64, 128) * 2
    B = torch.randn(128, 256) * 3

    # FP32 matmul
    C_fp32 = torch.matmul(A, B)

    print(f"\nMatrix shapes: A {A.shape} @ B {B.shape} = C {C_fp32.shape}")

    # Quantize matrices
    A_quant, A_scale = symmetric_quantize(A, num_bits=8)
    B_quant, B_scale = symmetric_quantize(B, num_bits=8)

    # Quantized matmul (in practice, use INT8 kernel)
    # Here we dequantize for demonstration
    A_dequant = symmetric_dequantize(A_quant, A_scale)
    B_dequant = symmetric_dequantize(B_quant, B_scale)
    C_quant = torch.matmul(A_dequant, B_dequant)

    # Compare results
    error = (C_fp32 - C_quant).abs()
    # Relative error as a ratio of Frobenius norms. The element-wise
    # mean(|error| / |C_fp32|) is dominated by outputs that happen to lie
    # near zero (and is inf if any output is exactly zero).
    relative_error = error.pow(2).sum().sqrt() / C_fp32.pow(2).sum().sqrt()

    print("\nResults:")
    print(f"  FP32 result range: [{C_fp32.min():.2f}, {C_fp32.max():.2f}]")
    print(f"  INT8 result range: [{C_quant.min():.2f}, {C_quant.max():.2f}]")
    print(f"  Max error: {error.max():.4f}")
    print(f"  Mean error: {error.mean():.4f}")
    print(f"  Relative error (Frobenius): {relative_error:.4%}")

    # Memory savings: INT8 operands, FP32 output, plus two FP32 scales (8 bytes)
    fp32_size = (A.numel() + B.numel() + C_fp32.numel()) * 4
    int8_size = (A.numel() + B.numel()) * 1 + C_fp32.numel() * 4 + 8
    print("\nMemory usage:")
    print(f"  FP32: {fp32_size / 1024:.2f} KB")
    print(f"  INT8: {int8_size / 1024:.2f} KB")
    print(f"  Savings: {(1 - int8_size/fp32_size) * 100:.1f}%")
399
400
if __name__ == "__main__":
    separator = "=" * 60
    print("\n" + separator)
    print("Foundation Models: Quantization")
    print(separator)

    # Run every demo in order; each prints its own section
    for demo in (
        demo_symmetric_quantization,
        demo_asymmetric_quantization,
        demo_per_channel_quantization,
        demo_bit_depth_comparison,
        demo_calibration,
        demo_quantized_matmul,
    ):
        demo()

    print("\n" + separator)
    print("Key Takeaways:")
    print(separator)
    takeaways = (
        "Quantization maps FP32 → INT8/INT4 for compression",
        "Symmetric: best for weights (centered at 0)",
        "Asymmetric: best for activations (arbitrary range)",
        "Per-channel: better accuracy for heterogeneous data",
        "Calibration: clip outliers to reduce quantization error",
        "INT8 provides ~4x compression with minimal accuracy loss",
    )
    for number, text in enumerate(takeaways, start=1):
        print(f"{number}. {text}")
    print(separator)