1"""
218. Efficient Attention Mechanisms
3
4Implementation of various attention mechanisms including:
5- Standard Multi-Head Attention
6- Flash Attention (via PyTorch 2.0+)
7- Sparse Attention patterns
8- Position encodings (Sinusoidal, RoPE, ALiBi)
9"""
10
11import torch
12import torch.nn as nn
13import torch.nn.functional as F
14import math
15import matplotlib.pyplot as plt
16import numpy as np
17import time
18
print("=" * 60)
print("Efficient Attention Mechanisms")
print("=" * 60)

# Prefer GPU when present; the benchmark section below keys off `device.type`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
26
27
28# ============================================
29# 1. Standard Multi-Head Attention
30# ============================================
31print("\n[1] Standard Multi-Head Attention")
32print("-" * 40)
33
34
class MultiHeadAttention(nn.Module):
    """Classic multi-head attention (Vaswani et al., 2017).

    Projects queries/keys/values, runs scaled dot-product attention per
    head, concatenates the heads, and applies a final output projection.
    """
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Scores are divided by sqrt(head_dim) to keep softmax well-scaled.
        self.scale = math.sqrt(self.head_dim)

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t, batch_size):
        # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
        return t.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask=None, return_attention=False):
        """Compute attention; `mask == 0` marks disallowed positions.

        Returns the projected output, plus the (batch, heads, q, k)
        attention weights when `return_attention` is set.
        """
        batch_size, q_len, _ = query.size()

        # Project inputs and reshape into per-head tensors.
        q = self._split_heads(self.W_q(query), batch_size)
        k = self._split_heads(self.W_k(key), batch_size)
        v = self._split_heads(self.W_v(value), batch_size)

        # Scaled dot-product scores: (batch, heads, q_len, k_len).
        scores = q @ k.transpose(-2, -1) / self.scale

        # Masked positions get -inf so softmax assigns them zero weight.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum over values, then merge heads back together.
        context = (attn @ v).transpose(1, 2).contiguous()
        context = context.view(batch_size, q_len, self.d_model)
        out = self.W_o(context)

        return (out, attn) if return_attention else out
90
91
# Test
# Smoke test: self-attention preserves the (batch, seq, d_model) shape.
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 100, 512)  # reused by later sections as shared input
out = mha(x, x, x)
print(f"Input: {x.shape}")
print(f"Output: {out.shape}")
98
99
100# ============================================
101# 2. PyTorch 2.0+ Scaled Dot-Product Attention
102# ============================================
103print("\n[2] PyTorch Scaled Dot-Product Attention")
104print("-" * 40)
105
106
class EfficientMultiHeadAttention(nn.Module):
    """Multi-head attention backed by F.scaled_dot_product_attention.

    PyTorch dispatches to fused kernels (Flash / memory-efficient
    attention) when the dtype, device, and mask combination allows it.
    """
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Kept as a float: passed straight to the fused kernel per call.
        self.dropout = dropout

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, query, key, value, mask=None, is_causal=False):
        """Fused attention; set `is_causal` for autoregressive masking."""
        batch_size, seq_len, _ = query.size()

        def to_heads(proj, t):
            # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
            return proj(t).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q = to_heads(self.W_q, query)
        k = to_heads(self.W_k, key)
        v = to_heads(self.W_v, value)

        # Dropout inside the kernel is only active while training.
        context = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        # (batch, heads, seq, head_dim) -> (batch, seq, d_model)
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.d_model)
        return self.W_o(context)
152
153
# Test efficient attention
# Same shape contract as the standard MHA smoke test; reuses `x`.
efficient_mha = EfficientMultiHeadAttention(d_model=512, num_heads=8)
out_efficient = efficient_mha(x, x, x)
print(f"Efficient MHA output: {out_efficient.shape}")
158
159
160# ============================================
161# 3. Attention Complexity Analysis
162# ============================================
163print("\n[3] Complexity Analysis")
164print("-" * 40)
165
166
def analyze_complexity(seq_lengths, d_model=512, num_heads=8):
    """Tabulate theoretical attention cost for each sequence length.

    Returns a list of dicts holding the O(n^2 * d) multiply count and
    the float32 size of the (heads x n x n) attention matrix in GB.
    """
    bytes_per_float32 = 4
    return [
        {
            'seq_len': n,
            'time_ops': n * n * d_model,                                # O(n^2 * d)
            'memory_gb': n * n * num_heads * bytes_per_float32 / 1024 ** 3,
        }
        for n in seq_lengths
    ]
186
187
# Show quadratic growth: memory is the float32 attention matrix alone.
seq_lengths = [128, 256, 512, 1024, 2048, 4096]
complexity = analyze_complexity(seq_lengths)

print("Sequence Length | Time Ops (M) | Memory (GB)")
print("-" * 45)
for r in complexity:
    print(f"{r['seq_len']:>14} | {r['time_ops']/1e6:>11.2f} | {r['memory_gb']:.4f}")
195
196
197# ============================================
198# 4. Sparse Attention Patterns
199# ============================================
200print("\n[4] Sparse Attention Patterns")
201print("-" * 40)
202
203
def create_local_mask(seq_len, window_size):
    """Sliding-window mask: row i attends to columns within window_size//2 of i."""
    half = window_size // 2
    rows = torch.arange(seq_len).unsqueeze(1)
    cols = torch.arange(seq_len).unsqueeze(0)
    # One vectorized comparison replaces the per-row slice fills.
    return ((cols - rows).abs() <= half).float()
212
213
def create_strided_mask(seq_len, stride):
    """Strided mask: every query attends to key positions 0, stride, 2*stride, ...

    The pattern is identical for every row, so a single slice assignment
    covers the whole matrix (the original loop rebuilt the same index
    list for each of the seq_len rows — O(n^2) Python work for nothing).
    """
    mask = torch.zeros(seq_len, seq_len)
    mask[:, ::stride] = 1
    return mask
221
222
def create_causal_mask(seq_len):
    """Lower-triangular mask: position i may attend only to positions <= i."""
    full = torch.ones(seq_len, seq_len)
    return full.tril()
226
227
def create_bigbird_mask(seq_len, window_size=64, num_global=2, num_random=3):
    """BigBird-style sparse mask: local window + global tokens + random links."""
    half = window_size // 2

    # Local sliding-window component, built in one vectorized comparison.
    offsets = torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)
    mask = (offsets.abs() <= half).float()

    # The first num_global tokens attend everywhere and are seen by everyone.
    mask[:num_global, :] = 1
    mask[:, :num_global] = 1

    # A handful of random key positions per query row.
    for row in range(seq_len):
        picks = torch.randperm(seq_len)[:num_random]
        mask[row, picks] = 1

    return mask
248
249
# Visualize masks
# Render the four sparse patterns side by side and save to disk.
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
seq_len = 64

masks = [
    ("Local (window=16)", create_local_mask(seq_len, 16)),
    ("Strided (stride=8)", create_strided_mask(seq_len, 8)),
    ("Causal", create_causal_mask(seq_len)),
    ("BigBird", create_bigbird_mask(seq_len, 16, 2, 3))
]

for ax, (name, mask) in zip(axes.flat, masks):
    ax.imshow(mask, cmap='Blues')  # dark cells = attended positions
    ax.set_title(name)
    ax.axis('off')

plt.tight_layout()
plt.savefig('attention_masks.png', dpi=150)
plt.close()
print("Attention masks saved to attention_masks.png")
270
271
class LocalAttention(nn.Module):
    """Sliding-window attention: each token only sees a local neighborhood."""
    def __init__(self, d_model, num_heads, window_size=256):
        super().__init__()
        self.window_size = window_size
        self.attention = MultiHeadAttention(d_model, num_heads)

    def forward(self, x):
        n = x.size(1)
        # Broadcast the (seq, seq) window mask over batch and head dims.
        window = create_local_mask(n, self.window_size).to(x.device)[None, None]
        return self.attention(x, x, x, mask=window)
284
285
286# ============================================
287# 5. Position Encodings
288# ============================================
289print("\n[5] Position Encodings")
290print("-" * 40)
291
292
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sin/cos position encoding from the original Transformer paper."""
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        position = torch.arange(max_len).float().unsqueeze(1)
        # Geometric frequency ladder: 10000^(-2i/d_model) over even dims.
        div_term = torch.exp(-math.log(10000.0) / d_model *
                             torch.arange(0, d_model, 2).float())

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sine
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cosine

        # Shape (1, max_len, d_model); a buffer follows .to() but is untrained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Add the first seq_len rows of the table to the input embeddings.
        return x + self.pe[:, :x.size(1)]
310
311
class LearnedPositionalEncoding(nn.Module):
    """Trainable absolute position embeddings (BERT / GPT-2 style)."""
    def __init__(self, d_model, max_len=512):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        idx = torch.arange(x.size(1), device=x.device)
        # The same position vectors are broadcast over the batch dimension.
        return x + self.pos_embedding(idx)
322
323
class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    Used in LLaMA, GPT-NeoX, etc. Rotates query/key channel pairs by a
    position-dependent angle so dot products depend on relative position.
    """
    def __init__(self, dim, max_len=2048, base=10000):
        super().__init__()
        self.dim = dim
        # Per-pair rotation frequencies: base^(-2i/dim).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        self._set_cos_sin_cache(max_len)

    def _set_cos_sin_cache(self, seq_len):
        """(Re)build cos/sin tables covering positions [0, seq_len)."""
        t = torch.arange(seq_len, device=self.inv_freq.device).float()
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Duplicate so the table spans the full `dim` channels.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def _rotate_half(self, x):
        # (x1, x2) -> (-x2, x1) over the two halves of the channel dim.
        x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q, k, seq_len):
        """Apply the rotation to q and k for the first `seq_len` positions.

        q, k: (..., seq_len, dim) tensors. Returns the rotated pair.
        """
        # Fix: grow the cache on demand. The original indexed past the
        # table whenever seq_len > max_len, causing a broadcast error.
        if seq_len > self.cos_cached.size(0):
            self._set_cos_sin_cache(seq_len)

        cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)

        q_embed = (q * cos) + (self._rotate_half(q) * sin)
        k_embed = (k * cos) + (self._rotate_half(k) * sin)

        return q_embed, k_embed
356
357
class ALiBiPositionalBias(nn.Module):
    """Attention with Linear Biases (ALiBi).

    Used in MPT, BLOOM, etc. Adds a head-specific linear penalty on the
    (absolute) query-key distance; no learned parameters and good
    length extrapolation.
    """
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.register_buffer('slopes', self._get_slopes(num_heads))

    def _get_slopes(self, n):
        def power_of_2_slopes(m):
            # Geometric sequence: first term 2^(-8/m), ratio equal to it.
            first = 2 ** (-(2 ** -(math.log2(m) - 3)))
            return [first * first ** i for i in range(m)]

        if math.log2(n).is_integer():
            return torch.tensor(power_of_2_slopes(n))
        # Non-power-of-two head counts: slopes of the nearest power of two
        # below n, topped up with every other slope from the next power up.
        base = 2 ** math.floor(math.log2(n))
        extra = power_of_2_slopes(2 * base)[0::2][:n - base]
        return torch.tensor(power_of_2_slopes(base) + extra)

    def forward(self, seq_len, device):
        pos = torch.arange(seq_len, device=device)
        # Symmetric |i - j| distance matrix, shape (1, seq_len, seq_len).
        dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs().unsqueeze(0).float()

        # Broadcast per-head slopes: (heads, 1, 1) * (1, n, n) -> (heads, n, n).
        return -self.slopes.unsqueeze(1).unsqueeze(1).to(device) * dist
393
394
# Visualize position encodings
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Sinusoidal
# Transposed so dimensions run vertically and positions horizontally.
sin_pe = SinusoidalPositionalEncoding(64)
pe_matrix = sin_pe.pe[0, :100, :].numpy()
axes[0].imshow(pe_matrix.T, aspect='auto', cmap='RdBu')
axes[0].set_title('Sinusoidal PE')
axes[0].set_xlabel('Position')
axes[0].set_ylabel('Dimension')

# ALiBi
# Single head's bias: linear penalty growing with |query - key| distance.
alibi = ALiBiPositionalBias(8)
alibi_bias = alibi(100, 'cpu')
axes[1].imshow(alibi_bias[0].numpy(), aspect='auto', cmap='RdBu')
axes[1].set_title('ALiBi Bias (Head 0)')
axes[1].set_xlabel('Key Position')
axes[1].set_ylabel('Query Position')

# ALiBi all heads
axes[2].imshow(alibi_bias.mean(0).numpy(), aspect='auto', cmap='RdBu')
axes[2].set_title('ALiBi Bias (Mean)')
axes[2].set_xlabel('Key Position')
axes[2].set_ylabel('Query Position')

plt.tight_layout()
plt.savefig('position_encodings.png', dpi=150)
plt.close()
print("Position encodings visualization saved to position_encodings.png")
424
425
426# ============================================
427# 6. Attention with ALiBi
428# ============================================
429print("\n[6] Attention with ALiBi")
430print("-" * 40)
431
432
class ALiBiMultiHeadAttention(nn.Module):
    """Multi-head attention whose scores carry an ALiBi distance penalty."""
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = math.sqrt(self.head_dim)

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.alibi = ALiBiPositionalBias(num_heads)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Standard MHA forward with an additive ALiBi bias on the scores."""
        batch_size, seq_len, _ = query.size()

        def split(proj, t):
            # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
            return proj(t).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q = split(self.W_q, query)
        k = split(self.W_k, key)
        v = split(self.W_v, value)

        # Dot-product scores plus the per-head linear distance bias.
        scores = q @ k.transpose(-2, -1) / self.scale
        scores = scores + self.alibi(seq_len, query.device)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = self.dropout(F.softmax(scores, dim=-1))

        # Weighted values, heads merged back into d_model.
        context = (attn @ v).transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.d_model)
        return self.W_o(context)
475
476
# Test ALiBi attention
# Shape-preserving like the other attention variants; reuses `x`.
alibi_mha = ALiBiMultiHeadAttention(d_model=512, num_heads=8)
out_alibi = alibi_mha(x, x, x)
print(f"ALiBi MHA output: {out_alibi.shape}")
481
482
483# ============================================
484# 7. Benchmark
485# ============================================
486print("\n[7] Performance Benchmark")
487print("-" * 40)
488
489
def benchmark_attention(attn_module, batch_size, seq_len, d_model, num_runs=10):
    """Return the mean forward-pass latency of `attn_module` in milliseconds.

    Runs 3 untimed warmup iterations, then averages `num_runs` timed
    self-attention calls on random input. Uses the module-level `device`;
    CUDA work is synchronized before reading the clock so asynchronous
    kernels are fully accounted for.
    """
    x = torch.randn(batch_size, seq_len, d_model, device=device)

    # Warmup (kernel selection / caching effects)
    with torch.no_grad():
        for _ in range(3):
            _ = attn_module(x, x, x)
    if device.type == 'cuda':
        torch.cuda.synchronize()

    # Timing. Fix: time.perf_counter() is monotonic and high-resolution;
    # time.time() is wall-clock and can jump under clock adjustments.
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_runs):
            _ = attn_module(x, x, x)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / num_runs

    return elapsed * 1000  # ms
511
512
# Only benchmark if not too slow
# CPU timings for these sizes would be misleadingly slow, so the
# standard-vs-fused comparison only runs when CUDA is available.
if device.type == 'cuda':
    print("\nBenchmarking on GPU...")
    configs = [
        (8, 256, 512),
        (8, 512, 512),
        (8, 1024, 512),
    ]

    results = []
    for batch, seq, dim in configs:
        # Fresh modules per config so parameter shapes match `dim`.
        standard_mha = MultiHeadAttention(dim, 8).to(device)
        efficient_mha = EfficientMultiHeadAttention(dim, 8).to(device)

        time_standard = benchmark_attention(standard_mha, batch, seq, dim)
        time_efficient = benchmark_attention(efficient_mha, batch, seq, dim)

        results.append({
            'config': f"({batch}, {seq}, {dim})",
            'standard': time_standard,
            'efficient': time_efficient,
            'speedup': time_standard / time_efficient
        })

    print("\nConfig | Standard (ms) | Efficient (ms) | Speedup")
    print("-" * 60)
    for r in results:
        print(f"{r['config']:16} | {r['standard']:12.2f} | {r['efficient']:13.2f} | {r['speedup']:.2f}x")
else:
    print("GPU not available. Skipping benchmark.")
543
544
545# ============================================
546# 8. Attention Visualization
547# ============================================
548print("\n[8] Attention Visualization")
549print("-" * 40)
550
551
def visualize_attention_weights(attention_weights, tokens=None, filename='attention_vis.png'):
    """Save a grid of per-head attention heatmaps to `filename`.

    attention_weights: (batch, heads, query, key) tensor; only batch 0 is
    plotted. Token tick labels are drawn when 10 or fewer are supplied.
    """
    num_heads = attention_weights.size(1)
    ncols = 4
    nrows = (num_heads + ncols - 1) // ncols

    fig, axes = plt.subplots(nrows, ncols, figsize=(3*ncols, 3*nrows))
    # Fix: plt.subplots returns a bare Axes for a 1x1 grid, which has no
    # .flatten(); np.atleast_1d makes this safe for any head count.
    axes = np.atleast_1d(axes).flatten()

    for head in range(num_heads):
        weights = attention_weights[0, head].cpu().detach().numpy()
        ax = axes[head]
        im = ax.imshow(weights, cmap='Blues')
        ax.set_title(f'Head {head}')

        if tokens and len(tokens) <= 10:
            ax.set_xticks(range(len(tokens)))
            ax.set_yticks(range(len(tokens)))
            ax.set_xticklabels(tokens, rotation=45, ha='right')
            ax.set_yticklabels(tokens)
        else:
            ax.set_xlabel('Key')
            ax.set_ylabel('Query')

    # Hide any unused grid cells.
    for i in range(num_heads, len(axes)):
        axes[i].axis('off')

    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    plt.close()
    # Fix: report the actual output path — the message previously printed
    # a hard-coded "(unknown)" placeholder instead of `filename`.
    print(f"Attention visualization saved to {filename}")
583
584
# Generate sample attention
# Small model + 10 tokens so the heatmaps get labeled tick marks.
mha_vis = MultiHeadAttention(d_model=64, num_heads=8)
x_vis = torch.randn(1, 10, 64)
_, attn_weights = mha_vis(x_vis, x_vis, x_vis, return_attention=True)
tokens = ['The', 'cat', 'sat', 'on', 'the', 'mat', 'and', 'then', 'left', '.']
visualize_attention_weights(attn_weights, tokens, 'attention_heads.png')
591
592
593# ============================================
594# Summary
595# ============================================
print("\n" + "=" * 60)
print("Efficient Attention Summary")
print("=" * 60)

# Plain-text recap of the concepts demonstrated above (printed verbatim).
summary = """
Key Concepts:
1. Standard Attention: O(n^2) time and space
2. Flash Attention: O(n^2) time, O(n) space via tiling
3. Sparse Attention: O(n * k) where k << n

Attention Patterns:
- Local: Sliding window attention
- Strided: Fixed-stride sparse pattern
- BigBird: Local + global + random
- Causal: Autoregressive masking

Position Encodings:
- Sinusoidal: Fixed, no parameters
- Learned: Trainable embeddings
- RoPE: Rotary, good relative position
- ALiBi: Linear bias, best extrapolation

PyTorch Tips:
1. Use F.scaled_dot_product_attention (PyTorch 2.0+)
2. Enable Flash Attention when possible
3. Use is_causal=True for autoregressive

Memory Comparison (seq_len=4096, heads=12):
- Standard: ~1.5 GB (attention matrix)
- Flash: ~0.1 GB (no full matrix storage)

Output Files:
- attention_masks.png: Sparse attention patterns
- position_encodings.png: PE visualizations
- attention_heads.png: Multi-head attention weights
"""
print(summary)
print("=" * 60)