# 10_attention_transformer.py
  1"""
  210. Attention๊ณผ Transformer
  3
  4Attention ๋ฉ”์ปค๋‹ˆ์ฆ˜๊ณผ Transformer๋ฅผ PyTorch๋กœ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค.
  5"""
  6
  7import torch
  8import torch.nn as nn
  9import torch.nn.functional as F
 10import math
 11import matplotlib.pyplot as plt
 12import numpy as np
 13
 14print("=" * 60)
 15print("PyTorch Attention & Transformer")
 16print("=" * 60)
 17
 18device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 19
 20
 21# ============================================
 22# 1. Scaled Dot-Product Attention
 23# ============================================
 24print("\n[1] Scaled Dot-Product Attention")
 25print("-" * 40)
 26
 27def scaled_dot_product_attention(Q, K, V, mask=None):
 28    """
 29    Args:
 30        Q: (batch, seq_q, d_k)
 31        K: (batch, seq_k, d_k)
 32        V: (batch, seq_k, d_v)
 33        mask: (batch, seq_q, seq_k) or broadcastable
 34    Returns:
 35        output: (batch, seq_q, d_v)
 36        attention_weights: (batch, seq_q, seq_k)
 37    """
 38    d_k = K.size(-1)
 39
 40    # ์Šค์ฝ”์–ด ๊ณ„์‚ฐ
 41    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
 42
 43    # ๋งˆ์Šคํ‚น
 44    if mask is not None:
 45        scores = scores.masked_fill(mask == 0, -1e9)
 46
 47    # Softmax
 48    attention_weights = F.softmax(scores, dim=-1)
 49
 50    # ๊ฐ€์ค‘ ํ•ฉ
 51    output = torch.matmul(attention_weights, V)
 52
 53    return output, attention_weights
 54
 55# ํ…Œ์ŠคํŠธ
 56batch_size = 2
 57seq_len = 5
 58d_k = 8
 59
 60Q = torch.randn(batch_size, seq_len, d_k)
 61K = torch.randn(batch_size, seq_len, d_k)
 62V = torch.randn(batch_size, seq_len, d_k)
 63
 64output, weights = scaled_dot_product_attention(Q, K, V)
 65print(f"Q, K, V: ({batch_size}, {seq_len}, {d_k})")
 66print(f"Output: {output.shape}")
 67print(f"Attention Weights: {weights.shape}")
 68print(f"Weights sum (should be 1): {weights[0, 0].sum().item():.4f}")
 69
 70
 71# ============================================
 72# 2. Multi-Head Attention
 73# ============================================
 74print("\n[2] Multi-Head Attention")
 75print("-" * 40)
 76
 77class MultiHeadAttention(nn.Module):
 78    def __init__(self, d_model, num_heads, dropout=0.1):
 79        super().__init__()
 80        assert d_model % num_heads == 0
 81
 82        self.d_model = d_model
 83        self.num_heads = num_heads
 84        self.d_k = d_model // num_heads
 85
 86        self.W_Q = nn.Linear(d_model, d_model)
 87        self.W_K = nn.Linear(d_model, d_model)
 88        self.W_V = nn.Linear(d_model, d_model)
 89        self.W_O = nn.Linear(d_model, d_model)
 90
 91        self.dropout = nn.Dropout(dropout)
 92
 93    def forward(self, Q, K, V, mask=None):
 94        batch_size = Q.size(0)
 95
 96        # ์„ ํ˜• ๋ณ€ํ™˜
 97        Q = self.W_Q(Q)
 98        K = self.W_K(K)
 99        V = self.W_V(V)
100
101        # ํ—ค๋“œ ๋ถ„ํ• : (batch, seq, d_model) โ†’ (batch, heads, seq, d_k)
102        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
103        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
104        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
105
106        # Attention
107        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
108
109        # ํ—ค๋“œ ๊ฒฐํ•ฉ
110        attn_output = attn_output.transpose(1, 2).contiguous()
111        attn_output = attn_output.view(batch_size, -1, self.d_model)
112
113        # ์ถœ๋ ฅ ๋ณ€ํ™˜
114        output = self.W_O(attn_output)
115
116        return output, attn_weights
117
# Quick smoke test
mha = MultiHeadAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)
output, weights = mha(x, x, x)
print(f"์ž…๋ ฅ: {x.shape}")
print(f"์ถœ๋ ฅ: {output.shape}")
print(f"Attention Weights: {weights.shape}")


# ============================================
# 3. Positional Encoding
# ============================================
print("\n[3] Positional Encoding")
print("-" * 40)
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    Adds position-dependent sin/cos signals to the input embeddings so the
    model can use token order. The table is precomputed once up to
    ``max_len`` positions and stored as a non-trainable buffer.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Precompute the table:
        #   PE[pos, 2i]   = sin(pos / 10000^(2i/d_model))
        #   PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # Slice div_term so odd d_model also works (no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])

        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        # Buffer: saved in state_dict and moved by .to(device), not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to x of shape (batch, seq, d_model)."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
154
155# ์‹œ๊ฐํ™”
156pe = PositionalEncoding(d_model=64)
157positions = pe.pe[0, :50, :].numpy()
158
159plt.figure(figsize=(12, 4))
160plt.imshow(positions.T, aspect='auto', cmap='RdBu')
161plt.xlabel('Position')
162plt.ylabel('Dimension')
163plt.title('Positional Encoding')
164plt.colorbar()
165plt.savefig('positional_encoding.png', dpi=100)
166plt.close()
167print("๊ทธ๋ž˜ํ”„ ์ €์žฅ: positional_encoding.png")
168
169
170# ============================================
171# 4. Feed Forward Network
172# ============================================
173print("\n[4] Feed Forward Network")
174print("-" * 40)
175
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Maps (..., d_model) -> (..., d_model) independently at every position.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
185
ff = FeedForward(d_model=64, d_ff=256)
x = torch.randn(2, 10, 64)
output = ff(x)
print(f"FFN ์ž…๋ ฅ: {x.shape} โ†’ ์ถœ๋ ฅ: {output.shape}")


# ============================================
# 5. Transformer Encoder Layer
# ============================================
print("\n[5] Transformer Encoder Layer")
print("-" * 40)
class TransformerEncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Two sublayers — multi-head self-attention and a position-wise FFN —
    each wrapped with residual connection, dropout, and LayerNorm.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention sublayer: residual + dropout + LayerNorm (post-norm).
        attn_out, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))

        # Feed-forward sublayer, same residual scheme.
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))

        return x
217
encoder_layer = TransformerEncoderLayer(d_model=64, num_heads=8, d_ff=256)
x = torch.randn(2, 10, 64)
output = encoder_layer(x)
print(f"์ธ์ฝ”๋” ์ธต ์ž…๋ ฅ: {x.shape} โ†’ ์ถœ๋ ฅ: {output.shape}")


# ============================================
# 6. Full Transformer Encoder
# ============================================
print("\n[6] Transformer Encoder")
print("-" * 40)
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` encoder layers with a final LayerNorm."""

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Pass through each layer in order, then normalize the final output.
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
243
encoder = TransformerEncoder(num_layers=6, d_model=64, num_heads=8, d_ff=256)
x = torch.randn(2, 10, 64)
output = encoder(x)
print(f"Transformer Encoder ์ถœ๋ ฅ: {output.shape}")
print(f"ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜: {sum(p.numel() for p in encoder.parameters()):,}")


# ============================================
# 7. Transformer Classifier
# ============================================
print("\n[7] Transformer ๋ถ„๋ฅ˜๊ธฐ")
print("-" * 40)
class TransformerClassifier(nn.Module):
    """Sequence classifier: embedding + positional encoding + Transformer
    encoder + mean pooling + linear head.

    Args:
        vocab_size: size of the token vocabulary.
        d_model: model (embedding) dimension.
        num_heads: attention heads per layer.
        num_layers: number of encoder layers.
        num_classes: number of output classes.
        d_ff: FFN hidden size; defaults to 4 * d_model.
        max_len: maximum supported sequence length.
        dropout: dropout probability used throughout.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, num_classes,
                 d_ff=None, max_len=512, dropout=0.1):
        super().__init__()
        d_ff = d_ff or d_model * 4

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
        self.fc = nn.Linear(d_model, num_classes)

        self._init_weights()

    def _init_weights(self):
        # Xavier-initialize every matrix-shaped parameter (this includes
        # the embedding table, overriding nn.Embedding's default init).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x, mask=None):
        # x: (batch, seq) token indices.
        # Scale embeddings by sqrt(d_model) as in the original Transformer.
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)
        x = self.encoder(x, mask)

        # Mean-pool over the sequence (a [CLS] token is an alternative).
        # NOTE(review): padding positions are included in this mean — confirm
        # callers pass fixed-length input or accept the dilution.
        x = x.mean(dim=1)

        return self.fc(x)
285
model = TransformerClassifier(
    vocab_size=10000,
    d_model=128,
    num_heads=8,
    num_layers=4,
    num_classes=5
)

x = torch.randint(0, 10000, (4, 32))
output = model(x)
print(f"๋ถ„๋ฅ˜๊ธฐ ์ž…๋ ฅ: {x.shape}")
print(f"๋ถ„๋ฅ˜๊ธฐ ์ถœ๋ ฅ: {output.shape}")
print(f"ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜: {sum(p.numel() for p in model.parameters()):,}")


# ============================================
# 8. PyTorch Built-in Transformer
# ============================================
print("\n[8] PyTorch ๋‚ด์žฅ Transformer")
print("-" * 40)
# Using the built-in nn.TransformerEncoder
pytorch_encoder_layer = nn.TransformerEncoderLayer(
    d_model=64,
    nhead=8,
    dim_feedforward=256,
    dropout=0.1,
    batch_first=True  # PyTorch 1.9+: inputs are (batch, seq, feature)
)

pytorch_encoder = nn.TransformerEncoder(pytorch_encoder_layer, num_layers=6)

x = torch.randn(2, 10, 64)
output = pytorch_encoder(x)
print(f"PyTorch Transformer ์ถœ๋ ฅ: {output.shape}")


# ============================================
# 9. Attention Visualization
# ============================================
print("\n[9] Attention ์‹œ๊ฐํ™”")
print("-" * 40)
def visualize_attention(attention_weights, tokens=None):
    """Save a heatmap of one attention head's weights.

    Args:
        attention_weights: (batch, heads, seq_q, seq_k) tensor; only the
            first batch element's first head is plotted.
        tokens: optional token labels for both axes.
    """
    # .cpu() so this also works for tensors living on a GPU.
    weights = attention_weights[0, 0].detach().cpu().numpy()

    plt.figure(figsize=(8, 6))
    plt.imshow(weights, cmap='Blues')
    plt.colorbar()

    if tokens:
        plt.xticks(range(len(tokens)), tokens, rotation=45)
        plt.yticks(range(len(tokens)), tokens)

    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.title('Attention Weights')
    plt.tight_layout()
    plt.savefig('attention_visualization.png', dpi=100)
    plt.close()
    print("๊ทธ๋ž˜ํ”„ ์ €์žฅ: attention_visualization.png")
348
# Example attention visualization
mha = MultiHeadAttention(d_model=64, num_heads=8)
x = torch.randn(1, 6, 64)
_, weights = mha(x, x, x)
visualize_attention(weights, ['The', 'cat', 'sat', 'on', 'mat', '.'])


# ============================================
# 10. Causal Mask (for the decoder)
# ============================================
print("\n[10] Causal Mask")
print("-" * 40)
def create_causal_mask(size):
    """Return a boolean (size, size) causal mask.

    True marks positions a token may attend to (itself and the past);
    False marks future positions, which must be hidden in a decoder.
    """
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask == 0
366
mask = create_causal_mask(5)
print(f"Causal Mask (5x5):\n{mask.int()}")


# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("Attention & Transformer ์ •๋ฆฌ")
print("=" * 60)

summary = """
Scaled Dot-Product Attention:
    scores = Q @ K.T / sqrt(d_k)
    weights = softmax(scores)
    output = weights @ V

Multi-Head Attention:
    - ์—ฌ๋Ÿฌ ํ—ค๋“œ๊ฐ€ ๋‹ค๋ฅธ ๊ด€๊ณ„ ํ•™์Šต
    - ๊ฐ ํ—ค๋“œ: d_k = d_model / num_heads

Transformer Encoder:
    - Self-Attention + FFN
    - Residual + LayerNorm
    - Positional Encoding

PyTorch ๋‚ด์žฅ:
    encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
    encoder = nn.TransformerEncoder(encoder_layer, num_layers)

ํ•ต์‹ฌ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ:
    - d_model: ๋ชจ๋ธ ์ฐจ์› (512)
    - num_heads: ํ—ค๋“œ ์ˆ˜ (8)
    - d_ff: FFN ์ฐจ์› (2048)
    - num_layers: ์ธต ์ˆ˜ (6)
"""
print(summary)
print("=" * 60)