# 12_attention_math.py
"""
Attention Mechanism Mathematics

This script demonstrates:
- Scaled dot-product attention from scratch
- Multi-head attention implementation
- Positional encoding (sinusoidal)
- Visualization of attention weights
- Comparison with PyTorch nn.MultiheadAttention

Attention mechanism:
    Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V

Where:
- Q: Query matrix
- K: Key matrix
- V: Value matrix
- d_k: Dimension of keys (for scaling)
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

 29def scaled_dot_product_attention(Q, K, V, mask=None):
 30    """
 31    Compute scaled dot-product attention.
 32
 33    Args:
 34        Q: Query matrix (batch, seq_len_q, d_k)
 35        K: Key matrix (batch, seq_len_k, d_k)
 36        V: Value matrix (batch, seq_len_k, d_v)
 37        mask: Optional mask (batch, seq_len_q, seq_len_k)
 38
 39    Returns:
 40        output: Attention output (batch, seq_len_q, d_v)
 41        attention_weights: Attention weights (batch, seq_len_q, seq_len_k)
 42    """
 43    d_k = Q.shape[-1]
 44
 45    # Compute attention scores: Q @ K^T / sqrt(d_k)
 46    scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
 47
 48    # Apply mask (if provided)
 49    if mask is not None:
 50        scores = scores.masked_fill(mask == 0, -1e9)
 51
 52    # Apply softmax to get attention weights
 53    attention_weights = F.softmax(scores, dim=-1)
 54
 55    # Compute weighted sum of values
 56    output = torch.matmul(attention_weights, V)
 57
 58    return output, attention_weights
 59
 60
 61class MultiHeadAttention(nn.Module):
 62    """
 63    Multi-head attention layer.
 64
 65    Multi-head attention allows the model to attend to information
 66    from different representation subspaces.
 67    """
 68
 69    def __init__(self, d_model, num_heads):
 70        super().__init__()
 71        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
 72
 73        self.d_model = d_model
 74        self.num_heads = num_heads
 75        self.d_k = d_model // num_heads
 76
 77        # Linear projections for Q, K, V
 78        self.W_q = nn.Linear(d_model, d_model)
 79        self.W_k = nn.Linear(d_model, d_model)
 80        self.W_v = nn.Linear(d_model, d_model)
 81
 82        # Output projection
 83        self.W_o = nn.Linear(d_model, d_model)
 84
 85    def split_heads(self, x):
 86        """
 87        Split last dimension into (num_heads, d_k).
 88        Transpose to (batch, num_heads, seq_len, d_k)
 89        """
 90        batch_size, seq_len, d_model = x.shape
 91        x = x.reshape(batch_size, seq_len, self.num_heads, self.d_k)
 92        return x.transpose(1, 2)
 93
 94    def combine_heads(self, x):
 95        """
 96        Inverse of split_heads.
 97        Transpose and reshape back to (batch, seq_len, d_model)
 98        """
 99        batch_size, num_heads, seq_len, d_k = x.shape
100        x = x.transpose(1, 2)
101        return x.reshape(batch_size, seq_len, self.d_model)
102
103    def forward(self, Q, K, V, mask=None):
104        """
105        Forward pass.
106
107        Args:
108            Q, K, V: Input tensors (batch, seq_len, d_model)
109            mask: Optional mask
110
111        Returns:
112            output: Attention output
113            attention_weights: Attention weights (for visualization)
114        """
115        # Linear projections
116        Q = self.W_q(Q)
117        K = self.W_k(K)
118        V = self.W_v(V)
119
120        # Split into multiple heads
121        Q = self.split_heads(Q)  # (batch, num_heads, seq_len_q, d_k)
122        K = self.split_heads(K)  # (batch, num_heads, seq_len_k, d_k)
123        V = self.split_heads(V)  # (batch, num_heads, seq_len_v, d_k)
124
125        # Apply attention
126        output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
127
128        # Combine heads
129        output = self.combine_heads(output)  # (batch, seq_len_q, d_model)
130
131        # Final linear projection
132        output = self.W_o(output)
133
134        return output, attention_weights
135
136
def positional_encoding(seq_len, d_model):
    """
    Generate the sinusoidal positional encoding from "Attention Is All
    You Need".

    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

    Args:
        seq_len: Number of positions to encode.
        d_model: Encoding dimension (odd values are supported).

    Returns:
        Positional encoding matrix as a torch.FloatTensor of shape
        (seq_len, d_model).
    """
    position = np.arange(seq_len)[:, np.newaxis]  # (seq_len, 1)

    # Frequencies 1 / 10000^(2i/d_model), computed in log space for
    # numerical stability; length is ceil(d_model / 2).
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))

    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    # Bug fix: for odd d_model there is one fewer cosine column than sine
    # column, so slice div_term to floor(d_model / 2) entries (a no-op
    # when d_model is even).
    pe[:, 1::2] = np.cos(position * div_term[: d_model // 2])

    return torch.FloatTensor(pe)
def visualize_positional_encoding():
    """
    Plot the sinusoidal positional encoding: a full heatmap plus a few
    individual dimensions over position; saves the figure to /tmp.
    """
    print("=== Positional Encoding ===\n")

    seq_len = 100
    d_model = 128

    pe = positional_encoding(seq_len, d_model)
    print(f"Positional encoding shape: {pe.shape}\n")

    fig, (heat_ax, line_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Left: full matrix, one column per position.
    image = heat_ax.imshow(pe.numpy().T, cmap='RdBu', aspect='auto',
                           interpolation='nearest')
    heat_ax.set_title('Positional Encoding Heatmap')
    heat_ax.set_xlabel('Position')
    heat_ax.set_ylabel('Dimension')
    plt.colorbar(image, ax=heat_ax)

    # Right: selected dimensions — lower dimensions oscillate faster.
    for dim in (4, 5, 16, 32):
        line_ax.plot(pe[:, dim].numpy(), label=f'dim {dim}')
    line_ax.set_title('Positional Encoding for Selected Dimensions')
    line_ax.set_xlabel('Position')
    line_ax.set_ylabel('Encoding Value')
    line_ax.legend()
    line_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('/tmp/positional_encoding.png', dpi=150, bbox_inches='tight')
    print("Plot saved to /tmp/positional_encoding.png\n")
    plt.close()
def visualize_attention_weights():
    """
    Run self-attention over random, position-encoded input and save a
    heatmap of the resulting attention matrix to /tmp.
    """
    print("=== Attention Weights Visualization ===\n")

    seq_len = 10
    d_model = 64
    batch_size = 1

    # Random embeddings plus positional information.
    X = torch.randn(batch_size, seq_len, d_model)
    X = X + positional_encoding(seq_len, d_model).unsqueeze(0)

    # Self-attention: the sequence attends to itself (Q = K = V = X).
    output, attn_weights = scaled_dot_product_attention(X, X, X)

    print(f"Input shape: {X.shape}")
    print(f"Attention weights shape: {attn_weights.shape}")
    print(f"Output shape: {output.shape}\n")

    # One row per query position, one column per key position.
    fig, ax = plt.subplots(figsize=(8, 7))
    sns.heatmap(attn_weights[0].detach().numpy(), cmap='viridis',
                xticklabels=range(seq_len), yticklabels=range(seq_len),
                cbar_kws={'label': 'Attention Weight'}, ax=ax)
    ax.set_title('Self-Attention Weights')
    ax.set_xlabel('Key Position')
    ax.set_ylabel('Query Position')

    plt.tight_layout()
    plt.savefig('/tmp/attention_weights.png', dpi=150, bbox_inches='tight')
    print("Plot saved to /tmp/attention_weights.png\n")
    plt.close()
def compare_with_pytorch():
    """
    Compare the custom MultiHeadAttention with torch.nn.MultiheadAttention.

    The custom module's projection weights are copied into the PyTorch
    module so both compute the same function; the outputs should then
    match up to floating-point tolerance.
    """
    print("=== Comparison with PyTorch nn.MultiheadAttention ===\n")

    batch_size = 2
    seq_len = 8
    d_model = 64
    num_heads = 8

    # Random input
    X = torch.randn(batch_size, seq_len, d_model)

    # Custom implementation
    custom_mha = MultiHeadAttention(d_model, num_heads)
    output_custom, attn_custom = custom_mha(X, X, X)

    # PyTorch implementation.
    # batch_first=True makes it accept (batch, seq_len, d_model) inputs,
    # matching the custom module, so no transpose is needed.
    pytorch_mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    # Copy weights from custom to PyTorch for fair comparison
    # (in practice, they would be trained, so outputs would differ).
    # nn.MultiheadAttention stores the three input projections stacked in
    # a single in_proj_weight / in_proj_bias, in Q, K, V row order.
    with torch.no_grad():
        pytorch_mha.in_proj_weight.copy_(torch.cat([
            custom_mha.W_q.weight,
            custom_mha.W_k.weight,
            custom_mha.W_v.weight
        ], dim=0))
        pytorch_mha.in_proj_bias.copy_(torch.cat([
            custom_mha.W_q.bias,
            custom_mha.W_k.bias,
            custom_mha.W_v.bias
        ], dim=0))
        pytorch_mha.out_proj.weight.copy_(custom_mha.W_o.weight)
        pytorch_mha.out_proj.bias.copy_(custom_mha.W_o.bias)

    # average_attn_weights=False keeps the per-head weights
    # (batch, num_heads, seq_q, seq_k), matching the custom module's shape.
    output_pytorch, attn_pytorch = pytorch_mha(X, X, X, average_attn_weights=False)

    print(f"Custom output shape: {output_custom.shape}")
    print(f"PyTorch output shape: {output_pytorch.shape}")
    print(f"Outputs close: {torch.allclose(output_custom, output_pytorch, atol=1e-5)}\n")

    print(f"Custom attention shape: {attn_custom.shape}")
    print(f"PyTorch attention shape: {attn_pytorch.shape}\n")

    # The outputs should be very close if weights are identical
    diff = torch.abs(output_custom - output_pytorch).max().item()
    print(f"Max absolute difference: {diff:.6f}\n")
def demonstrate_causal_masking():
    """
    Show how a lower-triangular (causal) mask restricts each query to
    earlier positions, plotting masked vs. unmasked attention side by side.
    """
    print("=== Causal Masking (Decoder Attention) ===\n")

    seq_len = 6
    d_model = 32
    batch_size = 1

    X = torch.randn(batch_size, seq_len, d_model)

    # Lower-triangular mask: query position i may only attend to j <= i.
    causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)
    print("Causal mask (lower triangular):")
    print(causal_mask[0].numpy().astype(int))
    print()

    # Self-attention with the causal mask applied.
    output, attn_weights = scaled_dot_product_attention(X, X, X, mask=causal_mask)

    print(f"Attention weights with causal mask:")
    print(f"Shape: {attn_weights.shape}\n")

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Unmasked baseline goes on the left for contrast.
    _, attn_no_mask = scaled_dot_product_attention(X, X, X, mask=None)
    panels = (
        (axes[0], attn_no_mask, 'Attention Without Mask'),
        (axes[1], attn_weights, 'Attention With Causal Mask'),
    )
    for ax, weights, title in panels:
        sns.heatmap(weights[0].detach().numpy(), cmap='viridis',
                    xticklabels=range(seq_len), yticklabels=range(seq_len),
                    cbar_kws={'label': 'Attention Weight'}, ax=ax)
        ax.set_title(title)
        ax.set_xlabel('Key Position')
        ax.set_ylabel('Query Position')

    plt.tight_layout()
    plt.savefig('/tmp/causal_masking.png', dpi=150, bbox_inches='tight')
    print("Plot saved to /tmp/causal_masking.png\n")
    plt.close()
def demonstrate_cross_attention():
    """
    Demonstrate cross-attention (encoder-decoder): decoder states act as
    queries over the encoder output, and the weight matrix is plotted.
    """
    print("=== Cross-Attention (Encoder-Decoder) ===\n")

    batch_size = 1
    encoder_seq_len = 8
    decoder_seq_len = 5
    d_model = 64

    # Encoder output supplies keys and values...
    encoder_output = torch.randn(batch_size, encoder_seq_len, d_model)
    # ...while decoder hidden states supply the queries.
    decoder_hidden = torch.randn(batch_size, decoder_seq_len, d_model)

    # Note the sequence lengths may differ between queries and keys.
    output, attn_weights = scaled_dot_product_attention(
        decoder_hidden, encoder_output, encoder_output)

    print(f"Encoder sequence length: {encoder_seq_len}")
    print(f"Decoder sequence length: {decoder_seq_len}")
    print(f"Cross-attention weights shape: {attn_weights.shape}")
    print(f"(decoder_seq_len x encoder_seq_len)\n")

    # Rows index decoder positions, columns index encoder positions.
    fig, ax = plt.subplots(figsize=(9, 6))
    sns.heatmap(attn_weights[0].detach().numpy(), cmap='viridis',
                xticklabels=range(encoder_seq_len),
                yticklabels=range(decoder_seq_len),
                cbar_kws={'label': 'Attention Weight'}, ax=ax)
    ax.set_title('Cross-Attention Weights (Decoder → Encoder)')
    ax.set_xlabel('Encoder Position')
    ax.set_ylabel('Decoder Position')

    plt.tight_layout()
    plt.savefig('/tmp/cross_attention.png', dpi=150, bbox_inches='tight')
    print("Plot saved to /tmp/cross_attention.png\n")
    plt.close()
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("Attention Mechanism Mathematics")
    print(banner)
    print()

    # Fixed seeds so every run produces the same tensors and plots.
    torch.manual_seed(42)
    np.random.seed(42)

    # Run each demonstration in order.
    for demo in (
        visualize_positional_encoding,
        visualize_attention_weights,
        compare_with_pytorch,
        demonstrate_causal_masking,
        demonstrate_cross_attention,
    ):
        demo()

    print(banner)
    print("Summary:")
    for point in (
        "- Attention allows dynamic weighting of input elements",
        "- Scaled dot-product prevents gradient issues with large d_k",
        "- Multi-head attention captures different representation subspaces",
        "- Positional encoding injects sequence order information",
        "- Causal masking enables autoregressive generation",
        "- Cross-attention connects encoder and decoder",
    ):
        print(point)
    print(banner)