1"""
2Foundation Models - Quantization Techniques
3
4Demonstrates INT8/INT4 quantization for model compression.
5Implements symmetric and asymmetric quantization schemes.
6Shows quantization error analysis and calibration concepts.
7
8Requires: PyTorch, NumPy
9"""
10
11import torch
12import numpy as np
13
14
def symmetric_quantize(tensor, num_bits=8):
    """
    Symmetric quantization: maps [-α, α] to [-2^(b-1), 2^(b-1)-1]

    Args:
        tensor: Input tensor
        num_bits: Number of bits (8 for INT8, 4 for INT4)

    Returns:
        Tuple of (quantized tensor, scale factor)
    """
    # Signed integer range for the requested bit width, e.g. [-128, 127] for INT8.
    q_min = -(2 ** (num_bits - 1))
    q_max = 2 ** (num_bits - 1) - 1

    # Scale maps the largest absolute value onto q_max.
    alpha = tensor.abs().max()
    scale = alpha / q_max

    # Guard against an all-zero input: a zero scale would make the division
    # below produce NaNs. Any nonzero scale maps zeros to zeros, so use 1.
    if scale == 0:
        scale = torch.ones_like(scale)

    # Quantize: round to the nearest integer step, then clamp into range.
    quantized = torch.round(tensor / scale).clamp(q_min, q_max)

    return quantized, scale
38
39
def symmetric_dequantize(quantized, scale):
    """
    Map symmetric-quantized integer levels back to real values.

    Args:
        quantized: Quantized tensor
        scale: Scale factor from quantization

    Returns:
        Dequantized tensor
    """
    # Inverse of round(x / scale): multiply each level back by the step size.
    return torch.mul(quantized, scale)
52
53
def asymmetric_quantize(tensor, num_bits=8):
    """
    Asymmetric quantization: maps [min, max] to [0, 2^b - 1]

    Args:
        tensor: Input tensor
        num_bits: Number of bits

    Returns:
        Tuple of (quantized tensor, scale, zero_point)
    """
    # Unsigned integer range, e.g. [0, 255] for INT8.
    q_min = 0
    q_max = 2 ** num_bits - 1

    # Calibration: observed real-valued range.
    r_min = tensor.min()
    r_max = tensor.max()

    # Scale maps the real range onto the integer range.
    scale = (r_max - r_min) / (q_max - q_min)

    # Guard against a constant tensor (r_min == r_max): a zero scale would
    # produce NaNs/Infs in the divisions below. Any nonzero scale works,
    # since every value then quantizes to the same level.
    if scale == 0:
        scale = torch.ones_like(scale)

    # Zero point: the integer level that real 0.0 maps to.
    zero_point = q_min - torch.round(r_min / scale)
    zero_point = zero_point.clamp(q_min, q_max)

    # Quantize: affine map followed by rounding and clamping.
    quantized = torch.round(tensor / scale + zero_point).clamp(q_min, q_max)

    return quantized, scale, zero_point
82
83
def asymmetric_dequantize(quantized, scale, zero_point):
    """
    Map asymmetric-quantized integer levels back to real values.

    Args:
        quantized: Quantized tensor
        scale: Scale factor
        zero_point: Zero point offset

    Returns:
        Dequantized tensor
    """
    # Undo the affine map: shift by the zero point, then rescale.
    shifted = quantized - zero_point
    return shifted * scale
97
98
def compute_quantization_error(original, dequantized):
    """
    Compute error metrics between a tensor and its quantize/dequantize
    reconstruction.

    Args:
        original: Original tensor
        dequantized: Dequantized tensor

    Returns:
        Dictionary of error metrics
    """
    diff = original - dequantized
    abs_err = diff.abs()

    # Signal-to-noise ratio in dB; the epsilon keeps the log finite when
    # the reconstruction is exact.
    signal_power = (original ** 2).mean()
    noise_power = (abs_err ** 2).mean()
    snr = 10 * torch.log10(signal_power / (noise_power + 1e-10))

    return {
        'max_error': abs_err.max().item(),
        'mean_error': abs_err.mean().item(),
        'mse': (diff ** 2).mean().item(),
        'snr_db': snr.item(),
    }
120
121
def per_channel_quantize(tensor, num_bits=8, dim=0):
    """
    Per-channel (per-output) symmetric quantization for better accuracy.

    Each channel along `dim` gets its own scale, so channels with very
    different magnitudes do not have to share one poorly fitting scale.

    Args:
        tensor: Input tensor (e.g., weight matrix)
        num_bits: Number of bits
        dim: Channel dimension

    Returns:
        Tuple of (quantized, scales), where scales is a 1-D tensor with one
        entry per channel.
    """
    q_min = -(2 ** (num_bits - 1))
    q_max = 2 ** (num_bits - 1) - 1

    # Move the channel dimension to the front so channels index axis 0.
    moved = tensor.transpose(0, dim) if dim != 0 else tensor

    # Per-channel max |x| in one vectorized pass instead of a Python loop
    # over channels.
    flat = moved.reshape(moved.shape[0], -1)
    alpha = flat.abs().amax(dim=1)
    scales = alpha / q_max

    # Guard all-zero channels: a zero scale would produce NaNs below.
    scales = torch.where(scales == 0, torch.ones_like(scales), scales)

    # Broadcast each channel's scale across its remaining dimensions.
    bshape = (moved.shape[0],) + (1,) * (moved.dim() - 1)
    quantized = torch.round(moved / scales.view(bshape)).clamp(q_min, q_max)

    # Restore the original dimension order.
    quantized = quantized.transpose(0, dim) if dim != 0 else quantized

    return quantized, scales
160
161
def per_channel_dequantize(quantized, scales, dim=0):
    """
    Dequantize a per-channel quantized tensor.

    Args:
        quantized: Quantized tensor
        scales: Per-channel scales (1-D, as returned by per_channel_quantize)
        dim: Channel dimension

    Returns:
        Dequantized tensor
    """
    # Accept a plain sequence of scales as well as a tensor.
    scales = torch.as_tensor(scales)

    # Move the channel dimension to the front, apply all per-channel scales
    # in a single broadcast multiply (no Python loop), then restore layout.
    moved = quantized.transpose(0, dim) if dim != 0 else quantized
    bshape = (moved.shape[0],) + (1,) * (moved.dim() - 1)
    dequantized = moved * scales.reshape(bshape)
    return dequantized.transpose(0, dim) if dim != 0 else dequantized
176
177
178# ============================================================
179# Demonstrations
180# ============================================================
181
def demo_symmetric_quantization():
    """Walk through symmetric quantization at two bit widths."""
    banner = "=" * 60
    print(banner)
    print("DEMO 1: Symmetric Quantization")
    print(banner)

    # Fixed seed keeps the demo output reproducible.
    torch.manual_seed(42)
    data = torch.randn(1000) * 10

    print(f"\nOriginal tensor:")
    print(f" Shape: {data.shape}")
    print(f" Range: [{data.min():.2f}, {data.max():.2f}]")
    print(f" Mean: {data.mean():.2f}, Std: {data.std():.2f}")

    # Quantize at INT8 and INT4 and report reconstruction quality.
    for bits in (8, 4):
        q, scale = symmetric_quantize(data, num_bits=bits)
        restored = symmetric_dequantize(q, scale)
        stats = compute_quantization_error(data, restored)

        print(f"\nINT{bits} Quantization:")
        print(f" Scale: {scale:.6f}")
        print(f" Quantized range: [{q.min():.0f}, {q.max():.0f}]")
        print(f" Max error: {stats['max_error']:.4f}")
        print(f" Mean error: {stats['mean_error']:.4f}")
        print(f" SNR: {stats['snr_db']:.2f} dB")
210
211
def demo_asymmetric_quantization():
    """Compare symmetric vs asymmetric INT8 on one-sided data."""
    banner = "=" * 60
    print("\n" + banner)
    print("DEMO 2: Asymmetric Quantization")
    print(banner)

    # ReLU output is mostly positive, i.e. an asymmetric range.
    torch.manual_seed(42)
    acts = torch.relu(torch.randn(1000) * 5 + 2)

    print(f"\nOriginal tensor (ReLU activations):")
    print(f" Range: [{acts.min():.2f}, {acts.max():.2f}]")
    print(f" Mean: {acts.mean():.2f}, Std: {acts.std():.2f}")

    # Symmetric quantization wastes half its range on this data.
    q_sym, s_sym = symmetric_quantize(acts, num_bits=8)
    restored = symmetric_dequantize(q_sym, s_sym)
    print(f"\nSymmetric INT8:")
    print(f" Scale: {s_sym:.6f}")
    stats = compute_quantization_error(acts, restored)
    print(f" Max error: {stats['max_error']:.4f}")
    print(f" Mean error: {stats['mean_error']:.4f}")
    print(f" SNR: {stats['snr_db']:.2f} dB")

    # Asymmetric quantization fits the observed [min, max] exactly.
    q_asym, s_asym, zp = asymmetric_quantize(acts, num_bits=8)
    restored = asymmetric_dequantize(q_asym, s_asym, zp)
    print(f"\nAsymmetric INT8:")
    print(f" Scale: {s_asym:.6f}, Zero point: {zp:.0f}")
    stats = compute_quantization_error(acts, restored)
    print(f" Max error: {stats['max_error']:.4f}")
    print(f" Mean error: {stats['mean_error']:.4f}")
    print(f" SNR: {stats['snr_db']:.2f} dB")
243
244
def demo_per_channel_quantization():
    """Show why per-channel scales beat one per-tensor scale."""
    banner = "=" * 60
    print("\n" + banner)
    print("DEMO 3: Per-Channel Quantization")
    print(banner)

    # Weight matrix whose rows (output channels) differ in magnitude.
    torch.manual_seed(42)
    W = torch.randn(256, 512)
    W[:64] *= 10      # large-magnitude channels
    W[64:128] *= 0.1  # small-magnitude channels

    print(f"\nWeight matrix:")
    print(f" Shape: {W.shape}")
    print(f" Overall range: [{W.min():.2f}, {W.max():.2f}]")
    print(f" Channel 0 range: [{W[0].min():.2f}, {W[0].max():.2f}]")
    print(f" Channel 100 range: [{W[100].min():.2f}, {W[100].max():.2f}]")

    # Baseline: a single scale for the whole tensor.
    q_t, s_t = symmetric_quantize(W, num_bits=8)
    m_t = compute_quantization_error(W, symmetric_dequantize(q_t, s_t))
    print(f"\nPer-tensor quantization:")
    print(f" Single scale: {s_t:.6f}")
    print(f" Max error: {m_t['max_error']:.4f}")
    print(f" SNR: {m_t['snr_db']:.2f} dB")

    # One scale per output channel adapts to each row's magnitude.
    q_c, s_c = per_channel_quantize(W, num_bits=8, dim=0)
    m_c = compute_quantization_error(W, per_channel_dequantize(q_c, s_c, dim=0))
    print(f"\nPer-channel quantization:")
    print(f" Scale range: [{s_c.min():.6f}, {s_c.max():.6f}]")
    print(f" Max error: {m_c['max_error']:.4f}")
    print(f" SNR: {m_c['snr_db']:.2f} dB")
    print(f" Improvement: {m_c['snr_db'] - m_t['snr_db']:.2f} dB")
285
286
def demo_bit_depth_comparison():
    """Tabulate the size/quality trade-off across bit widths."""
    banner = "=" * 60
    rule = "-" * 60
    print("\n" + banner)
    print("DEMO 4: Bit Depth Comparison")
    print(banner)

    torch.manual_seed(42)
    values = torch.randn(10000) * 5
    n = values.numel()

    print(f"\nOriginal tensor: {n} values")
    print(f" FP32 size: {n * 4 / 1024:.2f} KB")

    print("\n" + rule)
    print(f"{'Bits':<8} {'Size (KB)':<12} {'SNR (dB)':<12} {'Compression':<12}")
    print(rule)

    for bits in (2, 4, 8, 16):
        q, scale = symmetric_quantize(values, num_bits=bits)
        stats = compute_quantization_error(values, symmetric_dequantize(q, scale))

        # Payload: n values at `bits` bits each, plus 4 bytes for the scale.
        packed_bytes = n * bits / 8 + 4
        size_kb = packed_bytes / 1024
        ratio = (n * 4) / packed_bytes

        print(f"{bits:<8} {size_kb:<12.2f} {stats['snr_db']:<12.2f} {ratio:<12.2f}x")
314
315
def demo_calibration():
    """Contrast min-max calibration with percentile clipping."""
    banner = "=" * 60
    print("\n" + banner)
    print("DEMO 5: Calibration Strategies")
    print(banner)

    # Mostly-small data with a handful of large outliers.
    torch.manual_seed(42)
    samples = torch.randn(1000) * 2
    samples[torch.randint(0, 1000, (10,))] = torch.randn(10) * 50

    print(f"\nData statistics:")
    print(f" Min: {samples.min():.2f}, Max: {samples.max():.2f}")
    print(f" Mean: {samples.mean():.2f}, Std: {samples.std():.2f}")
    print(f" 99th percentile: {torch.quantile(samples.abs(), 0.99):.2f}")

    # Strategy 1: min-max — the outliers stretch the scale.
    q_mm, s_mm = symmetric_quantize(samples, num_bits=8)
    m_mm = compute_quantization_error(samples, symmetric_dequantize(q_mm, s_mm))
    print(f"\nMin-max calibration:")
    print(f" Clipping range: ±{samples.abs().max():.2f}")
    print(f" SNR: {m_mm['snr_db']:.2f} dB")

    # Strategy 2: clip at the 99th percentile before quantizing, trading a
    # few clipped outliers for a much tighter scale.
    percentile = 99
    clip_value = torch.quantile(samples.abs(), percentile / 100)
    clipped = samples.clamp(-clip_value, clip_value)
    q_p, s_p = symmetric_quantize(clipped, num_bits=8)
    m_p = compute_quantization_error(clipped, symmetric_dequantize(q_p, s_p))
    print(f"\n{percentile}th percentile calibration:")
    print(f" Clipping range: ±{clip_value:.2f}")
    print(f" SNR: {m_p['snr_db']:.2f} dB")
    print(f" Values clipped: {(samples.abs() > clip_value).sum().item()}")
354
355
def demo_quantized_matmul():
    """Show the accuracy impact of INT8 operands on a matmul."""
    banner = "=" * 60
    print("\n" + banner)
    print("DEMO 6: Quantized Matrix Multiplication")
    print(banner)

    torch.manual_seed(42)
    A = torch.randn(64, 128) * 2
    B = torch.randn(128, 256) * 3

    # Full-precision reference product.
    C_fp32 = torch.matmul(A, B)
    print(f"\nMatrix shapes: A {A.shape} @ B {B.shape} = C {C_fp32.shape}")

    # Quantize both operands to INT8, then dequantize back to FP32 to
    # emulate the precision loss (a real deployment would use an INT8 kernel).
    A_q, A_s = symmetric_quantize(A, num_bits=8)
    B_q, B_s = symmetric_quantize(B, num_bits=8)
    C_quant = torch.matmul(
        symmetric_dequantize(A_q, A_s),
        symmetric_dequantize(B_q, B_s),
    )

    diff = (C_fp32 - C_quant).abs()
    print(f"\nResults:")
    print(f" FP32 result range: [{C_fp32.min():.2f}, {C_fp32.max():.2f}]")
    print(f" INT8 result range: [{C_quant.min():.2f}, {C_quant.max():.2f}]")
    print(f" Max error: {diff.max():.4f}")
    print(f" Mean error: {diff.mean():.4f}")
    print(f" Relative error: {(diff / C_fp32.abs()).mean():.4%}")

    # Rough byte counts: FP32 everywhere vs INT8 inputs + FP32 output.
    fp32_size = (A.numel() + B.numel() + C_fp32.numel()) * 4
    int8_size = (A.numel() + B.numel()) * 1 + C_fp32.numel() * 4 + 8  # scales
    print(f"\nMemory usage:")
    print(f" FP32: {fp32_size / 1024:.2f} KB")
    print(f" INT8: {int8_size / 1024:.2f} KB")
    print(f" Savings: {(1 - int8_size / fp32_size) * 100:.1f}%")
399
400
if __name__ == "__main__":
    banner = "=" * 60
    print("\n" + banner)
    print("Foundation Models: Quantization")
    print(banner)

    # Run every demo in presentation order.
    for demo in (
        demo_symmetric_quantization,
        demo_asymmetric_quantization,
        demo_per_channel_quantization,
        demo_bit_depth_comparison,
        demo_calibration,
        demo_quantized_matmul,
    ):
        demo()

    print("\n" + banner)
    print("Key Takeaways:")
    print(banner)
    for takeaway in (
        "1. Quantization maps FP32 → INT8/INT4 for compression",
        "2. Symmetric: best for weights (centered at 0)",
        "3. Asymmetric: best for activations (arbitrary range)",
        "4. Per-channel: better accuracy for heterogeneous data",
        "5. Calibration: clip outliers to reduce quantization error",
        "6. INT8 provides ~4x compression with minimal accuracy loss",
    ):
        print(takeaway)
    print(banner)