"""
10. Attention and Transformer

Implements the attention mechanism and the Transformer architecture in PyTorch.
"""
6
7import torch
8import torch.nn as nn
9import torch.nn.functional as F
10import math
11import matplotlib.pyplot as plt
12import numpy as np
13
14print("=" * 60)
15print("PyTorch Attention & Transformer")
16print("=" * 60)
17
18device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20
21# ============================================
22# 1. Scaled Dot-Product Attention
23# ============================================
24print("\n[1] Scaled Dot-Product Attention")
25print("-" * 40)
26
27def scaled_dot_product_attention(Q, K, V, mask=None):
28 """
29 Args:
30 Q: (batch, seq_q, d_k)
31 K: (batch, seq_k, d_k)
32 V: (batch, seq_k, d_v)
33 mask: (batch, seq_q, seq_k) or broadcastable
34 Returns:
35 output: (batch, seq_q, d_v)
36 attention_weights: (batch, seq_q, seq_k)
37 """
38 d_k = K.size(-1)
39
40 # ์ค์ฝ์ด ๊ณ์ฐ
41 scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
42
43 # ๋ง์คํน
44 if mask is not None:
45 scores = scores.masked_fill(mask == 0, -1e9)
46
47 # Softmax
48 attention_weights = F.softmax(scores, dim=-1)
49
50 # ๊ฐ์ค ํฉ
51 output = torch.matmul(attention_weights, V)
52
53 return output, attention_weights
54
55# ํ
์คํธ
56batch_size = 2
57seq_len = 5
58d_k = 8
59
60Q = torch.randn(batch_size, seq_len, d_k)
61K = torch.randn(batch_size, seq_len, d_k)
62V = torch.randn(batch_size, seq_len, d_k)
63
64output, weights = scaled_dot_product_attention(Q, K, V)
65print(f"Q, K, V: ({batch_size}, {seq_len}, {d_k})")
66print(f"Output: {output.shape}")
67print(f"Attention Weights: {weights.shape}")
68print(f"Weights sum (should be 1): {weights[0, 0].sum().item():.4f}")
69
70
71# ============================================
72# 2. Multi-Head Attention
73# ============================================
74print("\n[2] Multi-Head Attention")
75print("-" * 40)
76
77class MultiHeadAttention(nn.Module):
78 def __init__(self, d_model, num_heads, dropout=0.1):
79 super().__init__()
80 assert d_model % num_heads == 0
81
82 self.d_model = d_model
83 self.num_heads = num_heads
84 self.d_k = d_model // num_heads
85
86 self.W_Q = nn.Linear(d_model, d_model)
87 self.W_K = nn.Linear(d_model, d_model)
88 self.W_V = nn.Linear(d_model, d_model)
89 self.W_O = nn.Linear(d_model, d_model)
90
91 self.dropout = nn.Dropout(dropout)
92
93 def forward(self, Q, K, V, mask=None):
94 batch_size = Q.size(0)
95
96 # ์ ํ ๋ณํ
97 Q = self.W_Q(Q)
98 K = self.W_K(K)
99 V = self.W_V(V)
100
101 # ํค๋ ๋ถํ : (batch, seq, d_model) โ (batch, heads, seq, d_k)
102 Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
103 K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
104 V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
105
106 # Attention
107 attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
108
109 # ํค๋ ๊ฒฐํฉ
110 attn_output = attn_output.transpose(1, 2).contiguous()
111 attn_output = attn_output.view(batch_size, -1, self.d_model)
112
113 # ์ถ๋ ฅ ๋ณํ
114 output = self.W_O(attn_output)
115
116 return output, attn_weights
117
118# ํ
์คํธ
119mha = MultiHeadAttention(d_model=64, num_heads=8)
120x = torch.randn(2, 10, 64)
121output, weights = mha(x, x, x)
122print(f"์
๋ ฅ: {x.shape}")
123print(f"์ถ๋ ฅ: {output.shape}")
124print(f"Attention Weights: {weights.shape}")
125
126
127# ============================================
128# 3. Positional Encoding
129# ============================================
130print("\n[3] Positional Encoding")
131print("-" * 40)
132
133class PositionalEncoding(nn.Module):
134 def __init__(self, d_model, max_len=5000, dropout=0.1):
135 super().__init__()
136 self.dropout = nn.Dropout(dropout)
137
138 # Positional Encoding ๊ณ์ฐ
139 pe = torch.zeros(max_len, d_model)
140 position = torch.arange(0, max_len).unsqueeze(1).float()
141 div_term = torch.exp(torch.arange(0, d_model, 2).float() *
142 (-math.log(10000.0) / d_model))
143
144 pe[:, 0::2] = torch.sin(position * div_term)
145 pe[:, 1::2] = torch.cos(position * div_term)
146
147 pe = pe.unsqueeze(0) # (1, max_len, d_model)
148 self.register_buffer('pe', pe)
149
150 def forward(self, x):
151 # x: (batch, seq, d_model)
152 x = x + self.pe[:, :x.size(1)]
153 return self.dropout(x)
154
155# ์๊ฐํ
156pe = PositionalEncoding(d_model=64)
157positions = pe.pe[0, :50, :].numpy()
158
159plt.figure(figsize=(12, 4))
160plt.imshow(positions.T, aspect='auto', cmap='RdBu')
161plt.xlabel('Position')
162plt.ylabel('Dimension')
163plt.title('Positional Encoding')
164plt.colorbar()
165plt.savefig('positional_encoding.png', dpi=100)
166plt.close()
167print("๊ทธ๋ํ ์ ์ฅ: positional_encoding.png")
168
169
170# ============================================
171# 4. Feed Forward Network
172# ============================================
173print("\n[4] Feed Forward Network")
174print("-" * 40)
175
176class FeedForward(nn.Module):
177 def __init__(self, d_model, d_ff, dropout=0.1):
178 super().__init__()
179 self.linear1 = nn.Linear(d_model, d_ff)
180 self.linear2 = nn.Linear(d_ff, d_model)
181 self.dropout = nn.Dropout(dropout)
182
183 def forward(self, x):
184 return self.linear2(self.dropout(F.relu(self.linear1(x))))
185
186ff = FeedForward(d_model=64, d_ff=256)
187x = torch.randn(2, 10, 64)
188output = ff(x)
189print(f"FFN ์
๋ ฅ: {x.shape} โ ์ถ๋ ฅ: {output.shape}")
190
191
192# ============================================
193# 5. Transformer Encoder Layer
194# ============================================
195print("\n[5] Transformer Encoder Layer")
196print("-" * 40)
197
198class TransformerEncoderLayer(nn.Module):
199 def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
200 super().__init__()
201 self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
202 self.ffn = FeedForward(d_model, d_ff, dropout)
203 self.norm1 = nn.LayerNorm(d_model)
204 self.norm2 = nn.LayerNorm(d_model)
205 self.dropout = nn.Dropout(dropout)
206
207 def forward(self, x, mask=None):
208 # Self-Attention + Residual + LayerNorm
209 attn_out, _ = self.self_attn(x, x, x, mask)
210 x = self.norm1(x + self.dropout(attn_out))
211
212 # FFN + Residual + LayerNorm
213 ffn_out = self.ffn(x)
214 x = self.norm2(x + self.dropout(ffn_out))
215
216 return x
217
218encoder_layer = TransformerEncoderLayer(d_model=64, num_heads=8, d_ff=256)
219x = torch.randn(2, 10, 64)
220output = encoder_layer(x)
221print(f"์ธ์ฝ๋ ์ธต ์
๋ ฅ: {x.shape} โ ์ถ๋ ฅ: {output.shape}")
222
223
224# ============================================
225# 6. Full Transformer Encoder
226# ============================================
227print("\n[6] Transformer Encoder")
228print("-" * 40)
229
230class TransformerEncoder(nn.Module):
231 def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
232 super().__init__()
233 self.layers = nn.ModuleList([
234 TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
235 for _ in range(num_layers)
236 ])
237 self.norm = nn.LayerNorm(d_model)
238
239 def forward(self, x, mask=None):
240 for layer in self.layers:
241 x = layer(x, mask)
242 return self.norm(x)
243
244encoder = TransformerEncoder(num_layers=6, d_model=64, num_heads=8, d_ff=256)
245x = torch.randn(2, 10, 64)
246output = encoder(x)
247print(f"Transformer Encoder ์ถ๋ ฅ: {output.shape}")
248print(f"ํ๋ผ๋ฏธํฐ ์: {sum(p.numel() for p in encoder.parameters()):,}")
249
250
251# ============================================
252# 7. Transformer ๋ถ๋ฅ๊ธฐ
253# ============================================
254print("\n[7] Transformer ๋ถ๋ฅ๊ธฐ")
255print("-" * 40)
256
257class TransformerClassifier(nn.Module):
258 def __init__(self, vocab_size, d_model, num_heads, num_layers, num_classes,
259 d_ff=None, max_len=512, dropout=0.1):
260 super().__init__()
261 d_ff = d_ff or d_model * 4
262
263 self.embedding = nn.Embedding(vocab_size, d_model)
264 self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
265 self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
266 self.fc = nn.Linear(d_model, num_classes)
267
268 self._init_weights()
269
270 def _init_weights(self):
271 for p in self.parameters():
272 if p.dim() > 1:
273 nn.init.xavier_uniform_(p)
274
275 def forward(self, x, mask=None):
276 # x: (batch, seq) - ํ ํฐ ์ธ๋ฑ์ค
277 x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
278 x = self.pos_encoding(x)
279 x = self.encoder(x, mask)
280
281 # ํ๊ท ํ๋ง ๋๋ [CLS] ํ ํฐ
282 x = x.mean(dim=1)
283
284 return self.fc(x)
285
286model = TransformerClassifier(
287 vocab_size=10000,
288 d_model=128,
289 num_heads=8,
290 num_layers=4,
291 num_classes=5
292)
293
294x = torch.randint(0, 10000, (4, 32))
295output = model(x)
296print(f"๋ถ๋ฅ๊ธฐ ์
๋ ฅ: {x.shape}")
297print(f"๋ถ๋ฅ๊ธฐ ์ถ๋ ฅ: {output.shape}")
298print(f"ํ๋ผ๋ฏธํฐ ์: {sum(p.numel() for p in model.parameters()):,}")
299
300
301# ============================================
302# 8. PyTorch ๋ด์ฅ Transformer
303# ============================================
304print("\n[8] PyTorch ๋ด์ฅ Transformer")
305print("-" * 40)
306
307# nn.TransformerEncoder ์ฌ์ฉ
308pytorch_encoder_layer = nn.TransformerEncoderLayer(
309 d_model=64,
310 nhead=8,
311 dim_feedforward=256,
312 dropout=0.1,
313 batch_first=True # PyTorch 1.9+
314)
315
316pytorch_encoder = nn.TransformerEncoder(pytorch_encoder_layer, num_layers=6)
317
318x = torch.randn(2, 10, 64)
319output = pytorch_encoder(x)
320print(f"PyTorch Transformer ์ถ๋ ฅ: {output.shape}")
321
322
323# ============================================
324# 9. Attention ์๊ฐํ
325# ============================================
326print("\n[9] Attention ์๊ฐํ")
327print("-" * 40)
328
329def visualize_attention(attention_weights, tokens=None):
330 """Attention ๊ฐ์ค์น ์๊ฐํ"""
331 weights = attention_weights[0, 0].detach().numpy() # ์ฒซ ๋ฐฐ์น, ์ฒซ ํค๋
332
333 plt.figure(figsize=(8, 6))
334 plt.imshow(weights, cmap='Blues')
335 plt.colorbar()
336
337 if tokens:
338 plt.xticks(range(len(tokens)), tokens, rotation=45)
339 plt.yticks(range(len(tokens)), tokens)
340
341 plt.xlabel('Key')
342 plt.ylabel('Query')
343 plt.title('Attention Weights')
344 plt.tight_layout()
345 plt.savefig('attention_visualization.png', dpi=100)
346 plt.close()
347 print("๊ทธ๋ํ ์ ์ฅ: attention_visualization.png")
348
349# ์์ attention ์๊ฐํ
350mha = MultiHeadAttention(d_model=64, num_heads=8)
351x = torch.randn(1, 6, 64)
352_, weights = mha(x, x, x)
353visualize_attention(weights, ['The', 'cat', 'sat', 'on', 'mat', '.'])
354
355
356# ============================================
357# 10. Causal Mask (๋์ฝ๋์ฉ)
358# ============================================
359print("\n[10] Causal Mask")
360print("-" * 40)
361
362def create_causal_mask(size):
363 """๋ฏธ๋ ํ ํฐ์ ๋ณผ ์ ์๊ฒ ํ๋ ๋ง์คํฌ"""
364 mask = torch.triu(torch.ones(size, size), diagonal=1)
365 return mask == 0 # True: ์ฐธ์กฐ ๊ฐ๋ฅ, False: ๋ง์คํน
366
367mask = create_causal_mask(5)
368print(f"Causal Mask (5x5):\n{mask.int()}")
369
370
371# ============================================
372# ์ ๋ฆฌ
373# ============================================
374print("\n" + "=" * 60)
375print("Attention & Transformer ์ ๋ฆฌ")
376print("=" * 60)
377
378summary = """
379Scaled Dot-Product Attention:
380 scores = Q @ K.T / sqrt(d_k)
381 weights = softmax(scores)
382 output = weights @ V
383
384Multi-Head Attention:
385 - ์ฌ๋ฌ ํค๋๊ฐ ๋ค๋ฅธ ๊ด๊ณ ํ์ต
386 - ๊ฐ ํค๋: d_k = d_model / num_heads
387
388Transformer Encoder:
389 - Self-Attention + FFN
390 - Residual + LayerNorm
391 - Positional Encoding
392
393PyTorch ๋ด์ฅ:
394 encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
395 encoder = nn.TransformerEncoder(encoder_layer, num_layers)
396
397ํต์ฌ ํ์ดํผํ๋ผ๋ฏธํฐ:
398 - d_model: ๋ชจ๋ธ ์ฐจ์ (512)
399 - num_heads: ํค๋ ์ (8)
400 - d_ff: FFN ์ฐจ์ (2048)
401 - num_layers: ์ธต ์ (6)
402"""
403print(summary)
404print("=" * 60)