1"""
2Attention Mechanism Mathematics
3
4This script demonstrates:
5- Scaled dot-product attention from scratch
6- Multi-head attention implementation
7- Positional encoding (sinusoidal)
8- Visualization of attention weights
9- Comparison with PyTorch nn.MultiheadAttention
10
11Attention mechanism:
12 Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
13
14Where:
15- Q: Query matrix
16- K: Key matrix
17- V: Value matrix
18- d_k: Dimension of keys (for scaling)
19"""
20
21import numpy as np
22import torch
23import torch.nn as nn
24import torch.nn.functional as F
25import matplotlib.pyplot as plt
26import seaborn as sns
27
28
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V.

    Args:
        Q: Queries, shape (batch, seq_len_q, d_k)
        K: Keys, shape (batch, seq_len_k, d_k)
        V: Values, shape (batch, seq_len_k, d_v)
        mask: Optional tensor broadcastable to (batch, seq_len_q, seq_len_k);
            positions where mask == 0 are blocked from attending.

    Returns:
        output: Weighted values, shape (batch, seq_len_q, d_v)
        weights: Attention weights, shape (batch, seq_len_q, seq_len_k)
    """
    # Query/key similarities, scaled by sqrt(d_k) so the softmax does not
    # saturate when the key dimension is large.
    scale = np.sqrt(Q.shape[-1])
    scores = (Q @ K.transpose(-2, -1)) / scale

    if mask is not None:
        # A huge negative logit drives the corresponding softmax weight to 0.
        scores = scores.masked_fill(mask == 0, -1e9)

    # Normalize over the key axis, then take the weighted sum of values.
    weights = torch.softmax(scores, dim=-1)
    output = weights @ V

    return output, weights
59
60
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention layer.

    d_model is partitioned into num_heads subspaces of size
    d_k = d_model // num_heads; each head runs scaled dot-product
    attention independently, and the results are concatenated and
    projected back to d_model.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Learned projections: queries, keys, values, then output.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """Reshape (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)."""
        batch, seq, _ = x.shape
        return x.reshape(batch, seq, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Reshape (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)."""
        batch, _, seq, _ = x.shape
        return x.transpose(1, 2).reshape(batch, seq, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """
        Run multi-head attention.

        Args:
            Q, K, V: Input tensors of shape (batch, seq_len, d_model)
            mask: Optional mask forwarded to scaled_dot_product_attention

        Returns:
            output: (batch, seq_len_q, d_model) attention result
            attention_weights: per-head weights, useful for visualization
        """
        # Project the inputs, then slice each projection into heads.
        q = self.split_heads(self.W_q(Q))  # (batch, num_heads, seq_len_q, d_k)
        k = self.split_heads(self.W_k(K))  # (batch, num_heads, seq_len_k, d_k)
        v = self.split_heads(self.W_v(V))  # (batch, num_heads, seq_len_v, d_k)

        # The head dimension rides along as an extra batch dimension, so
        # all heads are computed in one call.
        context, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        # Merge heads back together and apply the output projection.
        output = self.W_o(self.combine_heads(context))

        return output, attention_weights
135
136
def positional_encoding(seq_len, d_model):
    """
    Generate sinusoidal positional encoding (Vaswani et al., 2017).

    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

    Args:
        seq_len: Sequence length
        d_model: Model dimension. May be odd; in that case the final
            sine column has no cosine partner and the extra cosine
            value is simply dropped.

    Returns:
        Positional encoding matrix as a FloatTensor (seq_len, d_model)
    """
    position = np.arange(seq_len)[:, np.newaxis]
    # One inverse frequency per sin/cos pair, geometrically spaced from
    # 1 down to 1/10000 (computed in log space for numerical stability).
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))

    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    # The cosine half has floor(d_model / 2) columns, which is one fewer
    # than div_term's length when d_model is odd — trim to match (the
    # unsliced version raised a broadcasting error for odd d_model).
    pe[:, 1::2] = np.cos(position * div_term[: d_model // 2])

    return torch.FloatTensor(pe)
159
160
def visualize_positional_encoding():
    """
    Plot the sinusoidal positional encoding: a full heatmap plus a few
    individual dimensions, saved to /tmp/positional_encoding.png.
    """
    print("=== Positional Encoding ===\n")

    seq_len, d_model = 100, 128
    pe = positional_encoding(seq_len, d_model)
    print(f"Positional encoding shape: {pe.shape}\n")

    fig, (ax_map, ax_dims) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: every (position, dimension) value as a heatmap.
    image = ax_map.imshow(pe.numpy().T, cmap='RdBu', aspect='auto', interpolation='nearest')
    ax_map.set_title('Positional Encoding Heatmap')
    ax_map.set_xlabel('Position')
    ax_map.set_ylabel('Dimension')
    plt.colorbar(image, ax=ax_map)

    # Right panel: a handful of dimensions traced across positions, showing
    # the different frequencies.
    for dim in (4, 5, 16, 32):
        ax_dims.plot(pe[:, dim].numpy(), label=f'dim {dim}')
    ax_dims.set_title('Positional Encoding for Selected Dimensions')
    ax_dims.set_xlabel('Position')
    ax_dims.set_ylabel('Encoding Value')
    ax_dims.legend()
    ax_dims.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('/tmp/positional_encoding.png', dpi=150, bbox_inches='tight')
    print("Plot saved to /tmp/positional_encoding.png\n")
    plt.close()
198
199
def visualize_attention_weights():
    """
    Run self-attention over one random, positionally-encoded sequence and
    save the resulting attention matrix to /tmp/attention_weights.png.
    """
    print("=== Attention Weights Visualization ===\n")

    seq_len, d_model, batch_size = 10, 64, 1

    # Random embeddings with positional information added.
    X = torch.randn(batch_size, seq_len, d_model) + positional_encoding(seq_len, d_model).unsqueeze(0)

    # Self-attention: queries, keys, and values all come from X.
    output, weights = scaled_dot_product_attention(X, X, X)

    print(f"Input shape: {X.shape}")
    print(f"Attention weights shape: {weights.shape}")
    print(f"Output shape: {output.shape}\n")

    # Heatmap of the (query, key) attention matrix.
    fig, ax = plt.subplots(figsize=(8, 7))
    sns.heatmap(weights[0].detach().numpy(), cmap='viridis',
                xticklabels=range(seq_len), yticklabels=range(seq_len),
                cbar_kws={'label': 'Attention Weight'}, ax=ax)
    ax.set_title('Self-Attention Weights')
    ax.set_xlabel('Key Position')
    ax.set_ylabel('Query Position')

    plt.tight_layout()
    plt.savefig('/tmp/attention_weights.png', dpi=150, bbox_inches='tight')
    print("Plot saved to /tmp/attention_weights.png\n")
    plt.close()
241
242
def compare_with_pytorch():
    """
    Compare custom implementation with PyTorch nn.MultiheadAttention.

    The custom layer's projection weights are copied into the PyTorch
    module, so with identical inputs both should produce (numerically)
    identical outputs.
    """
    print("=== Comparison with PyTorch nn.MultiheadAttention ===\n")

    batch_size = 2
    seq_len = 8
    d_model = 64
    num_heads = 8

    # Random input
    X = torch.randn(batch_size, seq_len, d_model)

    # Custom implementation
    custom_mha = MultiHeadAttention(d_model, num_heads)
    output_custom, attn_custom = custom_mha(X, X, X)

    # PyTorch implementation
    # batch_first=True makes it accept (batch, seq_len, d_model) directly,
    # matching the custom layer's layout, so no transpose is needed.
    pytorch_mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    # Copy weights from custom to PyTorch for fair comparison
    # (In practice, they would be trained, so outputs would differ).
    # nn.MultiheadAttention packs the Q, K, V projections into a single
    # in_proj tensor stacked along dim 0 — the order below must match.
    with torch.no_grad():
        pytorch_mha.in_proj_weight.copy_(torch.cat([
            custom_mha.W_q.weight,
            custom_mha.W_k.weight,
            custom_mha.W_v.weight
        ], dim=0))
        pytorch_mha.in_proj_bias.copy_(torch.cat([
            custom_mha.W_q.bias,
            custom_mha.W_k.bias,
            custom_mha.W_v.bias
        ], dim=0))
        pytorch_mha.out_proj.weight.copy_(custom_mha.W_o.weight)
        pytorch_mha.out_proj.bias.copy_(custom_mha.W_o.bias)

    # average_attn_weights=False keeps the per-head weights, matching the
    # custom layer's (batch, num_heads, seq_len, seq_len) shape.
    output_pytorch, attn_pytorch = pytorch_mha(X, X, X, average_attn_weights=False)

    print(f"Custom output shape: {output_custom.shape}")
    print(f"PyTorch output shape: {output_pytorch.shape}")
    print(f"Outputs close: {torch.allclose(output_custom, output_pytorch, atol=1e-5)}\n")

    print(f"Custom attention shape: {attn_custom.shape}")
    print(f"PyTorch attention shape: {attn_pytorch.shape}\n")

    # The outputs should be very close if weights are identical
    diff = torch.abs(output_custom - output_pytorch).max().item()
    print(f"Max absolute difference: {diff:.6f}\n")
293
294
def demonstrate_causal_masking():
    """
    Show how a lower-triangular (causal) mask restricts each query to
    attend only to itself and earlier positions, as in a decoder.
    """
    print("=== Causal Masking (Decoder Attention) ===\n")

    seq_len, d_model, batch_size = 6, 32, 1

    X = torch.randn(batch_size, seq_len, d_model)

    # Lower-triangular mask: position i may see positions 0..i only.
    causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)
    print("Causal mask (lower triangular):")
    print(causal_mask[0].numpy().astype(int))
    print()

    # Self-attention with the causal mask applied.
    masked_out, masked_attn = scaled_dot_product_attention(X, X, X, mask=causal_mask)

    print("Attention weights with causal mask:")
    print(f"Shape: {masked_attn.shape}\n")

    # Side-by-side heatmaps: unrestricted attention vs causal attention.
    fig, (ax_free, ax_masked) = plt.subplots(1, 2, figsize=(12, 5))

    _, free_attn = scaled_dot_product_attention(X, X, X, mask=None)
    panels = (
        (ax_free, free_attn, 'Attention Without Mask'),
        (ax_masked, masked_attn, 'Attention With Causal Mask'),
    )
    for ax, attn, title in panels:
        sns.heatmap(attn[0].detach().numpy(), cmap='viridis',
                    xticklabels=range(seq_len), yticklabels=range(seq_len),
                    cbar_kws={'label': 'Attention Weight'}, ax=ax)
        ax.set_title(title)
        ax.set_xlabel('Key Position')
        ax.set_ylabel('Query Position')

    plt.tight_layout()
    plt.savefig('/tmp/causal_masking.png', dpi=150, bbox_inches='tight')
    print("Plot saved to /tmp/causal_masking.png\n")
    plt.close()
344
345
def demonstrate_cross_attention():
    """
    Cross-attention: decoder queries attend over encoder keys/values,
    yielding a (decoder_seq_len x encoder_seq_len) attention matrix.
    """
    print("=== Cross-Attention (Encoder-Decoder) ===\n")

    batch_size = 1
    encoder_seq_len = 8
    decoder_seq_len = 5
    d_model = 64

    # Keys and values come from the encoder; queries from the decoder.
    encoder_output = torch.randn(batch_size, encoder_seq_len, d_model)
    decoder_hidden = torch.randn(batch_size, decoder_seq_len, d_model)

    output, attn_weights = scaled_dot_product_attention(
        decoder_hidden, encoder_output, encoder_output)

    print(f"Encoder sequence length: {encoder_seq_len}")
    print(f"Decoder sequence length: {decoder_seq_len}")
    print(f"Cross-attention weights shape: {attn_weights.shape}")
    print("(decoder_seq_len x encoder_seq_len)\n")

    # Heatmap: rows are decoder positions, columns are encoder positions.
    fig, ax = plt.subplots(figsize=(9, 6))
    sns.heatmap(attn_weights[0].detach().numpy(), cmap='viridis',
                xticklabels=range(encoder_seq_len),
                yticklabels=range(decoder_seq_len),
                cbar_kws={'label': 'Attention Weight'}, ax=ax)
    ax.set_title('Cross-Attention Weights (Decoder → Encoder)')
    ax.set_xlabel('Encoder Position')
    ax.set_ylabel('Decoder Position')

    plt.tight_layout()
    plt.savefig('/tmp/cross_attention.png', dpi=150, bbox_inches='tight')
    print("Plot saved to /tmp/cross_attention.png\n")
    plt.close()
388
389
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("Attention Mechanism Mathematics")
    print(banner)
    print()

    # Seed both RNGs so every run produces identical figures.
    torch.manual_seed(42)
    np.random.seed(42)

    # Run every demonstration in order.
    for demo in (visualize_positional_encoding,
                 visualize_attention_weights,
                 compare_with_pytorch,
                 demonstrate_causal_masking,
                 demonstrate_cross_attention):
        demo()

    print(banner)
    print("Summary:")
    print("- Attention allows dynamic weighting of input elements")
    print("- Scaled dot-product prevents gradient issues with large d_k")
    print("- Multi-head attention captures different representation subspaces")
    print("- Positional encoding injects sequence order information")
    print("- Causal masking enables autoregressive generation")
    print("- Cross-attention connects encoder and decoder")
    print(banner)