18_efficient_attention.py

Download
python 634 lines 19.6 KB
  1"""
  218. Efficient Attention Mechanisms
  3
  4Implementation of various attention mechanisms including:
  5- Standard Multi-Head Attention
  6- Flash Attention (via PyTorch 2.0+)
  7- Sparse Attention patterns
  8- Position encodings (Sinusoidal, RoPE, ALiBi)
  9"""
 10
 11import torch
 12import torch.nn as nn
 13import torch.nn.functional as F
 14import math
 15import matplotlib.pyplot as plt
 16import numpy as np
 17import time
 18
 19print("=" * 60)
 20print("Efficient Attention Mechanisms")
 21print("=" * 60)
 22
 23device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 24print(f"Using device: {device}")
 25print(f"PyTorch version: {torch.__version__}")
 26
 27
 28# ============================================
 29# 1. Standard Multi-Head Attention
 30# ============================================
 31print("\n[1] Standard Multi-Head Attention")
 32print("-" * 40)
 33
 34
 35class MultiHeadAttention(nn.Module):
 36    """Standard Multi-Head Attention implementation"""
 37    def __init__(self, d_model, num_heads, dropout=0.1):
 38        super().__init__()
 39        assert d_model % num_heads == 0
 40
 41        self.d_model = d_model
 42        self.num_heads = num_heads
 43        self.head_dim = d_model // num_heads
 44        self.scale = math.sqrt(self.head_dim)
 45
 46        self.W_q = nn.Linear(d_model, d_model, bias=False)
 47        self.W_k = nn.Linear(d_model, d_model, bias=False)
 48        self.W_v = nn.Linear(d_model, d_model, bias=False)
 49        self.W_o = nn.Linear(d_model, d_model, bias=False)
 50
 51        self.dropout = nn.Dropout(dropout)
 52
 53    def forward(self, query, key, value, mask=None, return_attention=False):
 54        batch_size, seq_len, _ = query.size()
 55
 56        # Linear projections
 57        Q = self.W_q(query)
 58        K = self.W_k(key)
 59        V = self.W_v(value)
 60
 61        # Split into heads: (batch, seq, d_model) -> (batch, heads, seq, head_dim)
 62        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
 63        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
 64        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
 65
 66        # Attention scores
 67        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
 68
 69        # Apply mask
 70        if mask is not None:
 71            scores = scores.masked_fill(mask == 0, float('-inf'))
 72
 73        # Softmax and dropout
 74        attention_weights = F.softmax(scores, dim=-1)
 75        attention_weights = self.dropout(attention_weights)
 76
 77        # Apply attention to values
 78        attention_output = torch.matmul(attention_weights, V)
 79
 80        # Concatenate heads
 81        attention_output = attention_output.transpose(1, 2).contiguous()
 82        attention_output = attention_output.view(batch_size, seq_len, self.d_model)
 83
 84        # Output projection
 85        output = self.W_o(attention_output)
 86
 87        if return_attention:
 88            return output, attention_weights
 89        return output
 90
 91
 92# Test
 93mha = MultiHeadAttention(d_model=512, num_heads=8)
 94x = torch.randn(2, 100, 512)
 95out = mha(x, x, x)
 96print(f"Input: {x.shape}")
 97print(f"Output: {out.shape}")
 98
 99
100# ============================================
101# 2. PyTorch 2.0+ Scaled Dot-Product Attention
102# ============================================
103print("\n[2] PyTorch Scaled Dot-Product Attention")
104print("-" * 40)
105
106
class EfficientMultiHeadAttention(nn.Module):
    """Multi-Head Attention using PyTorch's scaled_dot_product_attention

    Automatically uses Flash Attention when available.

    ``mask`` follows the same convention as ``MultiHeadAttention`` in this
    file: nonzero / True entries may be attended to, 0 / False entries are
    blocked. The mask is converted to a boolean tensor before it is handed to
    ``F.scaled_dot_product_attention``. That function *adds* a floating-point
    ``attn_mask`` to the scores instead of using it as a keep/drop mask.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: attention-dropout probability, applied only in training mode.
    """
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.dropout = dropout

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, query, key, value, mask=None, is_causal=False):
        """Project the inputs, attend with SDPA, and merge the heads.

        Args:
            query: tensor unpacked as (batch, seq_len, d_model).
            key, value: tensors with the same batch size and d_model; their
                length may differ from ``query``'s.
            mask: optional 0/1 or boolean mask broadcastable to
                (batch, num_heads, seq_len, key_len); 0/False blocks a position.
            is_causal: forwarded to SDPA to apply causal masking.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = query.size()

        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # Reshape for multi-head
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # SDPA treats a boolean mask as "True = take part in attention", which
        # matches the 0/1 convention used elsewhere in this file. A float 0/1
        # mask would instead be added to the scores and block nothing.
        attn_mask = None if mask is None else mask != 0

        # Use PyTorch's efficient attention; dropout only while training.
        dropout_p = self.dropout if self.training else 0.0
        attention_output = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal
        )

        # Reshape back: (batch, heads, seq, head_dim) -> (batch, seq, d_model)
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, seq_len, self.d_model)

        return self.W_o(attention_output)
152
153
# Smoke test: run the SDPA-based module on the same (2, 100, 512) input.
efficient_mha = EfficientMultiHeadAttention(d_model=512, num_heads=8)
out_efficient = efficient_mha(x, x, x)
print(f"Efficient MHA output: {out_efficient.shape}")


# ============================================
# 3. Attention Complexity Analysis
# ============================================
print("\n[3] Complexity Analysis", "-" * 40, sep="\n")
165
166
def analyze_complexity(seq_lengths, d_model=512, num_heads=8):
    """Estimate attention cost for each sequence length in ``seq_lengths``.

    Returns a list of dicts, one per length, with keys:
        'seq_len':   the sequence length n
        'time_ops':  n^2 * d_model (score-matrix multiply-adds, O(n^2 * d))
        'memory_gb': size in GiB of num_heads float32 n x n attention matrices
    """
    bytes_per_float32 = 4
    gib = 1024 ** 3
    return [
        {
            'seq_len': n,
            'time_ops': n ** 2 * d_model,
            'memory_gb': n ** 2 * num_heads * bytes_per_float32 / gib,
        }
        for n in seq_lengths
    ]
186
187
seq_lengths = [128, 256, 512, 1024, 2048, 4096]
complexity = analyze_complexity(seq_lengths)

# Field widths match the header labels so the "|" separators line up.
# "Sequence Length" is 15 characters wide and "Time Ops (M)" is 12.
print("Sequence Length | Time Ops (M) | Memory (GB)")
print("-" * 45)
for r in complexity:
    print(f"{r['seq_len']:>15} | {r['time_ops']/1e6:>12.2f} | {r['memory_gb']:.4f}")


# ============================================
# 4. Sparse Attention Patterns
# ============================================
print("\n[4] Sparse Attention Patterns")
print("-" * 40)
202
203
def create_local_mask(seq_len, window_size):
    """Create local (sliding window) attention mask.

    Query ``i`` may attend to key ``j`` iff ``|i - j| <= window_size // 2``.
    This is a band of ``window_size // 2`` positions on each side of the
    diagonal, clipped at the sequence edges.

    Args:
        seq_len: number of positions (rows and columns of the mask).
        window_size: total window width; only ``window_size // 2`` is used.

    Returns:
        A ``(seq_len, seq_len)`` tensor in the default float dtype, with 1
        where attention is allowed and 0 elsewhere.
    """
    positions = torch.arange(seq_len)
    # Pairwise |query - key| distance, built in one vectorized op rather than
    # a Python loop over rows (this mask is rebuilt on every LocalAttention call).
    distance = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()
    return (distance <= window_size // 2).to(torch.get_default_dtype())
212
213
def create_strided_mask(seq_len, stride):
    """Create strided attention mask.

    Every query row attends to the same fixed set of key columns
    ``0, stride, 2 * stride, ...``.

    Args:
        seq_len: number of positions (rows and columns of the mask).
        stride: step between attended columns; must be positive.

    Returns:
        A ``(seq_len, seq_len)`` 0/1 float tensor.

    Raises:
        ValueError: if ``stride`` is not positive.
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")
    mask = torch.zeros(seq_len, seq_len)
    # The selected columns are the same for every row, so set them in one slice.
    mask[:, ::stride] = 1
    return mask
221
222
def create_causal_mask(seq_len):
    """Create causal (autoregressive) mask: 1 on and below the diagonal, 0 above."""
    full = torch.ones(seq_len, seq_len)
    return full.tril()
226
227
def create_bigbird_mask(seq_len, window_size=64, num_global=2, num_random=3, generator=None):
    """Create BigBird sparse attention mask.

    The mask is the union of three patterns:
      * local: query i sees key j when ``|i - j| <= window_size // 2``;
      * global: the first ``num_global`` rows and columns are all ones;
      * random: each row additionally gets up to ``num_random`` random keys.

    Args:
        seq_len: number of positions (rows and columns of the mask).
        window_size: local window width (``window_size // 2`` on each side).
        num_global: number of leading global tokens.
        num_random: random keys drawn per row via ``torch.randperm``.
        generator: optional ``torch.Generator`` for the random links. Pass a
            seeded generator for a reproducible mask; ``None`` uses the global
            RNG.

    Returns:
        A ``(seq_len, seq_len)`` 0/1 tensor in the default float dtype.
    """
    positions = torch.arange(seq_len)
    distance = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()
    # Local band, vectorized instead of a per-row Python loop.
    mask = (distance <= window_size // 2).to(torch.get_default_dtype())

    # Global tokens attend everywhere and are attended to by every query.
    mask[:num_global, :] = 1
    mask[:, :num_global] = 1

    # Random connections: one randperm per row, in row order, so the RNG
    # draw sequence is the same as drawing row by row.
    for i in range(seq_len):
        random_indices = torch.randperm(seq_len, generator=generator)[:num_random]
        mask[i, random_indices] = 1

    return mask
248
249
# Visualize masks: a 2x2 grid, one panel per pattern on a 64-token sequence.
seq_len = 64
fig, axes = plt.subplots(2, 2, figsize=(10, 10))

masks = [
    ("Local (window=16)", create_local_mask(seq_len, 16)),
    ("Strided (stride=8)", create_strided_mask(seq_len, 8)),
    ("Causal", create_causal_mask(seq_len)),
    ("BigBird", create_bigbird_mask(seq_len, 16, 2, 3)),
]

for ax, (name, mask) in zip(axes.flat, masks):
    ax.imshow(mask, cmap='Blues')
    ax.set(title=name)
    ax.set_axis_off()

plt.tight_layout()
plt.savefig('attention_masks.png', dpi=150)
plt.close(fig)
print("Attention masks saved to attention_masks.png")
270
271
class LocalAttention(nn.Module):
    """Local (Sliding Window) Attention

    Wraps ``MultiHeadAttention`` with a banded mask from ``create_local_mask``.
    Each position only attends within ``window_size // 2`` positions on either
    side. The full seq_len x seq_len score matrix is still computed and then
    masked, so this restricts *what* is attended to, not the quadratic cost.
    """
    def __init__(self, d_model, num_heads, window_size=256):
        super().__init__()
        self.window_size = window_size
        self.attention = MultiHeadAttention(d_model, num_heads)

    def forward(self, x):
        """Self-attend over ``x`` (read as (batch, seq_len, d_model))."""
        band = create_local_mask(x.shape[1], self.window_size).to(x.device)
        # Two leading singleton dims let the mask broadcast over batch and heads.
        return self.attention(x, x, x, mask=band[None, None])
284
285
# ============================================
# 5. Position Encodings
# ============================================
print("\n[5] Position Encodings", "-" * 40, sep="\n")
291
292
class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal Position Encoding (original Transformer)

    Even feature indices hold ``sin(pos * w_k)`` and odd indices hold
    ``cos(pos * w_k)``, with ``w_k = 10000^(-2k / d_model)``. The table is
    precomputed for ``max_len`` positions and stored as a (1, max_len, d_model)
    buffer.

    Args:
        d_model: encoding width (odd widths are supported).
        max_len: number of positions precomputed.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column,
        # so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encodings for positions ``0 .. x.size(1) - 1`` to ``x``.

        Assumes ``x`` is (batch, seq_len, d_model) with seq_len <= max_len.
        """
        return x + self.pe[:, :x.size(1)]
310
311
class LearnedPositionalEncoding(nn.Module):
    """Learned Position Embeddings

    One trainable ``d_model``-vector per position, looked up from an
    ``nn.Embedding`` with ``max_len`` rows and added to the input.
    """
    def __init__(self, d_model, max_len=512):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the embeddings of positions ``0 .. x.size(1) - 1``."""
        index = torch.arange(x.shape[1], device=x.device)
        table = self.pos_embedding(index)
        return x + table
322
323
class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE)

    Used in LLaMA, GPT-NeoX, etc.

    Rotates feature pairs ``(i, i + dim/2)`` of queries and keys by a
    position-dependent angle. The cos/sin tables are cached for ``max_len``
    positions and are rebuilt larger when a longer sequence is requested.

    Args:
        dim: per-head feature size of q/k; must be even.
        max_len: number of positions cached up front.
        base: frequency base of the rotation angles.

    Raises:
        ValueError: if ``dim`` is odd.
    """
    def __init__(self, dim, max_len=2048, base=10000):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE requires an even dim, got {dim}")
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        self._set_cos_sin_cache(max_len)

    def _set_cos_sin_cache(self, seq_len):
        """(Re)build the cos/sin tables, each of shape (seq_len, dim)."""
        t = torch.arange(seq_len, device=self.inv_freq.device).float()
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def _rotate_half(self, x):
        """Map the two halves (x1, x2) of the last dim to (-x2, x1)."""
        x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q, k, seq_len):
        """Apply the rotation to ``q`` and ``k``.

        Assumes q/k are (batch, heads, seq_len, dim); the cached tables get two
        leading singleton dims to broadcast that way.

        Returns:
            ``(q_embed, k_embed)`` in the dtypes of ``q`` and ``k``.
        """
        if seq_len > self.cos_cached.size(0):
            # Grow the cache instead of failing on sequences longer than max_len.
            self._set_cos_sin_cache(seq_len)

        cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)

        # Cast the cached tables to the input dtype so half/bfloat16 inputs are
        # not silently promoted to float32.
        q_embed = (q * cos.to(q.dtype)) + (self._rotate_half(q) * sin.to(q.dtype))
        k_embed = (k * cos.to(k.dtype)) + (self._rotate_half(k) * sin.to(k.dtype))

        return q_embed, k_embed
356
357
class ALiBiPositionalBias(nn.Module):
    """Attention with Linear Biases (ALiBi)

    Used in MPT, BLOOM, etc.
    No learned parameters, good length extrapolation.

    Head h has a fixed slope m_h. The bias added to the score of query i and
    key j is ``-m_h * |i - j|``, which is symmetric in i and j here.
    """
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.register_buffer('slopes', self._get_slopes(num_heads))

    def _get_slopes(self, n):
        """Return the ``n`` per-head slopes as a 1-D float tensor.

        For a power of two, the slopes are the geometric sequence
        ``s, s^2, ..., s^n`` with ``s = 2^(-8/n)``. Otherwise the slopes for
        the nearest lower power of two are used, topped up with every other
        slope of the next power of two.
        """
        def geometric(count):
            first = 2 ** (-(2 ** -(math.log2(count) - 3)))
            return [first * first ** power for power in range(count)]

        if not math.log2(n).is_integer():
            lower = 2 ** math.floor(math.log2(n))
            extra = geometric(2 * lower)[::2][: n - lower]
            return torch.tensor(geometric(lower) + extra)
        return torch.tensor(geometric(n))

    def forward(self, seq_len, device):
        """Return the bias tensor of shape (num_heads, seq_len, seq_len) on ``device``."""
        idx = torch.arange(seq_len, device=device)
        distance = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs().unsqueeze(0).float()
        per_head = self.slopes.unsqueeze(1).unsqueeze(1).to(device)
        return -per_head * distance
393
394
# Visualize position encodings: sinusoidal table plus ALiBi biases, side by side.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Sinusoidal: first 100 positions of a 64-dim table, dimensions on the y-axis.
sin_pe = SinusoidalPositionalEncoding(64)
pe_matrix = sin_pe.pe[0, :100, :].numpy()
axes[0].imshow(pe_matrix.T, aspect='auto', cmap='RdBu')
axes[0].set(title='Sinusoidal PE', xlabel='Position', ylabel='Dimension')

# ALiBi: bias for a 100-token sequence with 8 heads; head 0 alone, then head mean.
alibi = ALiBiPositionalBias(8)
alibi_bias = alibi(100, 'cpu')
axes[1].imshow(alibi_bias[0].numpy(), aspect='auto', cmap='RdBu')
axes[1].set(title='ALiBi Bias (Head 0)', xlabel='Key Position', ylabel='Query Position')

axes[2].imshow(alibi_bias.mean(0).numpy(), aspect='auto', cmap='RdBu')
axes[2].set(title='ALiBi Bias (Mean)', xlabel='Key Position', ylabel='Query Position')

plt.tight_layout()
plt.savefig('position_encodings.png', dpi=150)
plt.close(fig)
print("Position encodings visualization saved to position_encodings.png")


# ============================================
# 6. Attention with ALiBi
# ============================================
print("\n[6] Attention with ALiBi", "-" * 40, sep="\n")
431
432
class ALiBiMultiHeadAttention(nn.Module):
    """Multi-Head Attention with ALiBi position bias

    Uses the same projections and the same "mask == 0 is blocked" convention as
    MultiHeadAttention. In addition, a per-head distance penalty from
    ALiBiPositionalBias is added to the scaled scores before the softmax.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads (one ALiBi slope per head).
        dropout: dropout probability applied to the attention weights.
    """
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = math.sqrt(self.head_dim)

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.alibi = ALiBiPositionalBias(num_heads)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None, return_attention=False):
        """Attend with ALiBi-biased scores.

        Args:
            query, key, value: tensors unpacked as (batch, seq_len, d_model).
                The bias is built for a seq_len x seq_len grid, so key/value are
                assumed to have the same length as query.
            mask: optional tensor broadcastable to
                (batch, num_heads, seq_len, seq_len); entries equal to 0 are blocked.
            return_attention: if True, also return the (post-dropout) attention
                weights, as MultiHeadAttention does.

        Returns:
            Output of shape (batch, seq_len, d_model), or ``(output, weights)``.
        """
        batch_size, seq_len, _ = query.size()

        Q = self.W_q(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores with ALiBi bias. The bias is float32 (its relative
        # distances are built with .float()). Cast it to the scores' dtype so
        # half/bfloat16 models keep one dtype through softmax and the matmul with V.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        alibi_bias = self.alibi(seq_len, query.device).to(scores.dtype)
        scores = scores + alibi_bias

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        attention_output = torch.matmul(attention_weights, V)
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, seq_len, self.d_model)

        output = self.W_o(attention_output)
        if return_attention:
            return output, attention_weights
        return output
475
476
# Smoke test: ALiBi attention on the shared (2, 100, 512) input.
alibi_mha = ALiBiMultiHeadAttention(d_model=512, num_heads=8)
out_alibi = alibi_mha(x, x, x)
print(f"ALiBi MHA output: {out_alibi.shape}")


# ============================================
# 7. Benchmark
# ============================================
print("\n[7] Performance Benchmark", "-" * 40, sep="\n")
488
489
def benchmark_attention(attn_module, batch_size, seq_len, d_model, num_runs=10, num_warmup=3):
    """Benchmark attention module

    Times ``attn_module(x, x, x)`` under ``torch.no_grad()`` on a random
    (batch_size, seq_len, d_model) input. Returns the mean wall-clock time per
    call in milliseconds.

    The input is created on the device of the module's first parameter. A
    module without parameters falls back to the module-level ``device``.

    Args:
        attn_module: module called as ``attn_module(x, x, x)``.
        batch_size, seq_len, d_model: shape of the random input.
        num_runs: number of timed calls; must be >= 1.
        num_warmup: number of untimed calls made first.

    Raises:
        ValueError: if ``num_runs`` < 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    param = next(attn_module.parameters(), None)
    run_device = param.device if param is not None else device
    x = torch.randn(batch_size, seq_len, d_model, device=run_device)

    def _sync():
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        if run_device.type == 'cuda':
            torch.cuda.synchronize()

    with torch.no_grad():
        # Warmup
        for _ in range(num_warmup):
            attn_module(x, x, x)
        _sync()

        # Timing: perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(num_runs):
            attn_module(x, x, x)
        _sync()
        elapsed = (time.perf_counter() - start) / num_runs

    return elapsed * 1000  # ms
511
512
# Benchmark only on GPU; on CPU the larger configs would be too slow.
if device.type == 'cuda':
    print("\nBenchmarking on GPU...")
    configs = [
        (8, 256, 512),
        (8, 512, 512),
        (8, 1024, 512),
    ]

    results = []
    for batch, seq, dim in configs:
        # eval() disables dropout in both modules, so the timings (taken under
        # no_grad) measure the attention computation rather than dropout.
        standard_mha = MultiHeadAttention(dim, 8).to(device).eval()
        efficient_mha = EfficientMultiHeadAttention(dim, 8).to(device).eval()

        time_standard = benchmark_attention(standard_mha, batch, seq, dim)
        time_efficient = benchmark_attention(efficient_mha, batch, seq, dim)

        results.append({
            'config': f"({batch}, {seq}, {dim})",
            'standard': time_standard,
            'efficient': time_efficient,
            'speedup': time_standard / time_efficient
        })

    # Field widths (17 / 13 / 14) match the header labels so the "|" columns line up.
    print("\nConfig            | Standard (ms) | Efficient (ms) | Speedup")
    print("-" * 60)
    for r in results:
        print(f"{r['config']:17} | {r['standard']:13.2f} | {r['efficient']:14.2f} | {r['speedup']:.2f}x")
else:
    print("GPU not available. Skipping benchmark.")


# ============================================
# 8. Attention Visualization
# ============================================
print("\n[8] Attention Visualization")
print("-" * 40)
550
551
def visualize_attention_weights(attention_weights, tokens=None, filename='attention_vis.png'):
    """Visualize attention weights heatmap

    Plots one heatmap per head for batch element 0, in a grid with 4 columns,
    and saves the figure to ``filename``.

    Args:
        attention_weights: tensor indexed as [batch, head, query, key], e.g. the
            weights returned by MultiHeadAttention with ``return_attention=True``.
        tokens: optional token labels. They are used as tick labels only when
            given and at most 10 long; otherwise generic axis labels are drawn.
        filename: path of the saved image.
    """
    num_heads = attention_weights.size(1)
    ncols = 4
    nrows = (num_heads + ncols - 1) // ncols  # ceil(num_heads / ncols)

    fig, axes = plt.subplots(nrows, ncols, figsize=(3*ncols, 3*nrows))
    axes = axes.flatten()

    for head in range(num_heads):
        # detach() first so the copy to CPU is not tracked by autograd.
        weights = attention_weights[0, head].detach().cpu().numpy()
        ax = axes[head]
        ax.imshow(weights, cmap='Blues')
        ax.set_title(f'Head {head}')

        if tokens and len(tokens) <= 10:
            ax.set_xticks(range(len(tokens)))
            ax.set_yticks(range(len(tokens)))
            ax.set_xticklabels(tokens, rotation=45, ha='right')
            ax.set_yticklabels(tokens)
        else:
            ax.set_xlabel('Key')
            ax.set_ylabel('Query')

    # Hide the unused cells of the last grid row.
    for ax in axes[num_heads:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    # Close the figure created above explicitly rather than "whatever is current".
    plt.close(fig)
    print(f"Attention visualization saved to {filename}")
583
584
# Generate sample attention. MultiHeadAttention applies dropout to the weights it
# returns, so eval() is needed for the plotted rows to be plain softmax
# distributions instead of dropout-zeroed/rescaled copies. no_grad because
# nothing here is trained.
mha_vis = MultiHeadAttention(d_model=64, num_heads=8).eval()
x_vis = torch.randn(1, 10, 64)
with torch.no_grad():
    _, attn_weights = mha_vis(x_vis, x_vis, x_vis, return_attention=True)
tokens = ['The', 'cat', 'sat', 'on', 'the', 'mat', 'and', 'then', 'left', '.']
visualize_attention_weights(attn_weights, tokens, 'attention_heads.png')


# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("Efficient Attention Summary")
print("=" * 60)
599
# Closing cheat-sheet printed at the end of the script.
# NOTE(review): one float32 4096x4096 matrix for 12 heads is
# 4096 * 4096 * 12 * 4 bytes ~= 0.75 GiB. The "~1.5 GB" figure below presumably
# counts two such matrices (pre-softmax scores and softmax weights); confirm.
summary = """
Key Concepts:
1. Standard Attention: O(n^2) time and space
2. Flash Attention: O(n^2) time, O(n) space via tiling
3. Sparse Attention: O(n * k) where k << n

Attention Patterns:
- Local: Sliding window attention
- Strided: Fixed-stride sparse pattern
- BigBird: Local + global + random
- Causal: Autoregressive masking

Position Encodings:
- Sinusoidal: Fixed, no parameters
- Learned: Trainable embeddings
- RoPE: Rotary, good relative position
- ALiBi: Linear bias, best extrapolation

PyTorch Tips:
1. Use F.scaled_dot_product_attention (PyTorch 2.0+)
2. Enable Flash Attention when possible
3. Use is_causal=True for autoregressive

Memory Comparison (seq_len=4096, heads=12):
- Standard: ~1.5 GB (attention matrix)
- Flash: ~0.1 GB (no full matrix storage)

Output Files:
- attention_masks.png: Sparse attention patterns
- position_encodings.png: PE visualizations
- attention_heads.png: Multi-head attention weights
"""
print(summary, "=" * 60, sep="\n")