1"""
2Transformer - PyTorch Low-Level ๊ตฌํ
3
4์ด ํ์ผ์ Transformer๋ฅผ PyTorch ๊ธฐ๋ณธ ์ฐ์ฐ๋ง์ผ๋ก ๊ตฌํํฉ๋๋ค.
5nn.TransformerEncoder, nn.MultiheadAttention ๋ฑ ๊ณ ์์ค API๋ฅผ ์ฌ์ฉํ์ง ์๊ณ
6์ง์ attention๊ณผ FFN์ ๊ตฌํํฉ๋๋ค.
7
8๋
ผ๋ฌธ: "Attention Is All You Need" (Vaswani et al., 2017)
9
10ํ์ต ๋ชฉํ:
111. Scaled Dot-Product Attention ๊ตฌํ
122. Multi-Head Attention ๊ตฌํ
133. Positional Encoding ๊ตฌํ
144. Encoder/Decoder ๋ธ๋ก ๊ตฌํ
15"""
16
17import torch
18import torch.nn as nn
19import torch.nn.functional as F
20import math
21
22
def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor = None,
    dropout: nn.Dropout = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Scaled Dot-Product Attention

    Attention(Q, K, V) = softmax(QK^T / √d_k) V

    Args:
        query: (batch, n_heads, seq_len, d_k)
        key: (batch, n_heads, seq_len, d_k)
        value: (batch, n_heads, seq_len, d_v)
        mask: (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)
        dropout: Dropout layer

    Returns:
        output: (batch, n_heads, seq_len, d_v)
        attention_weights: (batch, n_heads, seq_len, seq_len)
    """
    d_k = query.size(-1)

    # 1. QK^T: similarity between Query and Key
    # (batch, heads, seq, d_k) @ (batch, heads, d_k, seq) → (batch, heads, seq, seq)
    scores = torch.matmul(query, key.transpose(-2, -1))

    # 2. Scaling: divide by √d_k (keeps the softmax numerically stable)
    scores = scores / math.sqrt(d_k)

    # 3. Masking (optional)
    if mask is not None:
        # Positions where mask is True are set to -inf (become 0 after softmax)
        scores = scores.masked_fill(mask, float('-inf'))

    # 4. Softmax: convert scores into a probability distribution
    attention_weights = F.softmax(scores, dim=-1)

    # 5. Dropout (during training)
    if dropout is not None:
        attention_weights = dropout(attention_weights)

    # 6. Weighted sum of values
    output = torch.matmul(attention_weights, value)

    return output, attention_weights


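# Optional sanity check (illustrative, not used by the model): compare this
# implementation against PyTorch's built-in F.scaled_dot_product_attention
# (available in PyTorch >= 2.0). Without a mask and without dropout the two
# should agree up to floating-point tolerance.
def _check_attention_against_torch(atol: float = 1e-5) -> bool:
    torch.manual_seed(0)
    q = torch.randn(2, 4, 5, 8)  # (batch, n_heads, seq_len, d_k)
    k = torch.randn(2, 4, 5, 8)
    v = torch.randn(2, 4, 5, 8)
    ours, _ = scaled_dot_product_attention(q, k, v)
    ref = F.scaled_dot_product_attention(q, k, v)
    return torch.allclose(ours, ref, atol=atol)

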
class MultiHeadAttentionLowLevel(nn.Module):
    """
    Multi-Head Attention (low-level implementation)

    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O

    Implemented directly, without using nn.MultiheadAttention.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        """
        Args:
            d_model: model dimension
            n_heads: number of attention heads
            dropout: dropout rate
        """
        super().__init__()

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension of each head

        # Q, K, V projections (these could also be fused into a single linear layer)
        # Managing raw parameters directly instead of nn.Linear is also possible
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        # Output projection
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            query: (batch, seq_len, d_model)
            key: (batch, seq_len, d_model)
            value: (batch, seq_len, d_model)
            mask: (batch, seq_len) or (batch, seq_len, seq_len)

        Returns:
            output: (batch, seq_len, d_model)
        """
        batch_size = query.size(0)

        # 1. Linear projections
        # (batch, seq, d_model) → (batch, seq, d_model)
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # 2. Split into multiple heads
        # (batch, seq, d_model) → (batch, seq, n_heads, d_k) → (batch, n_heads, seq, d_k)
        Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # 3. Adjust mask dimensions (for broadcasting)
        if mask is not None:
            if mask.dim() == 2:
                # (batch, seq) → (batch, 1, 1, seq)
                mask = mask.unsqueeze(1).unsqueeze(2)
            elif mask.dim() == 3:
                # (batch, seq, seq) → (batch, 1, seq, seq)
                mask = mask.unsqueeze(1)

        # 4. Attention
        attn_output, _ = scaled_dot_product_attention(Q, K, V, mask, self.dropout)

        # 5. Concat heads
        # (batch, n_heads, seq, d_k) → (batch, seq, n_heads, d_k) → (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)

        # 6. Output projection
        output = self.W_o(attn_output)

        return output


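# Usage sketch (illustrative shapes): self-attention with Q = K = V keeps the
# (batch, seq_len, d_model) shape; the boolean padding mask uses True for
# positions that should be ignored.
def _demo_multi_head_attention() -> torch.Size:
    mha = MultiHeadAttentionLowLevel(d_model=64, n_heads=8)
    x = torch.randn(2, 10, 64)                        # (batch, seq_len, d_model)
    pad_mask = torch.zeros(2, 10, dtype=torch.bool)   # True = masked position
    pad_mask[:, 8:] = True                            # pretend the last 2 tokens are padding
    out = mha(x, x, x, pad_mask)
    return out.shape                                  # torch.Size([2, 10, 64])

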
class PositionalEncoding(nn.Module):
    """
    Sinusoidal Positional Encoding

    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Precompute the positional encodings
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # 1 / 10000^(2i/d_model) = exp(-2i * log(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices

        pe = pe.unsqueeze(0)  # (1, max_len, d_model)

        # Register as a non-learnable buffer
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)

        Returns:
            x + PE: (batch, seq_len, d_model)
        """
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)


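# Quick illustrative check (not required by the model): for dimension index 0,
# div_term[0] = 1, so PE(pos, 0) = sin(pos). The buffer's first column should
# therefore match the sine of the position index exactly.
def _demo_positional_encoding() -> bool:
    enc = PositionalEncoding(d_model=16, dropout=0.0)
    positions = torch.arange(0, 8, dtype=torch.float)
    return torch.allclose(enc.pe[0, :8, 0], torch.sin(positions), atol=1e-6)

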
class FeedForwardLowLevel(nn.Module):
    """
    Position-wise Feed-Forward Network

    FFN(x) = GELU(xW_1 + b_1)W_2 + b_2

    Typically d_ff = 4 * d_model (expansion factor of 4)
    """

    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()

        if d_ff is None:
            d_ff = 4 * d_model

        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)

        Returns:
            output: (batch, seq_len, d_model)
        """
        # GELU activation (the original paper uses ReLU, but modern models prefer GELU)
        x = self.linear1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


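# Parameter-count sketch (illustrative numbers): with the default 4x expansion,
# d_model=256 gives d_ff=1024, i.e. two weight matrices of 256*1024 and
# 1024*256 entries plus their biases (525,568 parameters in total).
def _demo_feed_forward_params() -> int:
    ffn = FeedForwardLowLevel(d_model=256)  # d_ff defaults to 4 * 256 = 1024
    return sum(p.numel() for p in ffn.parameters())

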
class TransformerEncoderBlock(nn.Module):
    """
    Transformer Encoder Block

    Structure:
        x → LayerNorm → MultiHeadAttention → Dropout → Add(x) →
          → LayerNorm → FeedForward → Dropout → Add(x) → output
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()

        self.attention = MultiHeadAttentionLowLevel(d_model, n_heads, dropout)
        self.feed_forward = FeedForwardLowLevel(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: padding mask

        Returns:
            output: (batch, seq_len, d_model)
        """
        # Pre-norm (the modern convention; the original paper uses post-norm)
        # Self-Attention + Residual
        normed = self.norm1(x)
        attn_out = self.attention(normed, normed, normed, mask)
        x = x + self.dropout(attn_out)

        # Feed-Forward + Residual
        normed = self.norm2(x)
        ff_out = self.feed_forward(normed)
        x = x + self.dropout(ff_out)

        return x


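# Usage sketch (illustrative values): encoder blocks are shape-preserving, so
# they can be stacked freely; the padding mask (True = pad) is shared by every
# block in the stack.
def _demo_encoder_stack() -> torch.Size:
    blocks = nn.ModuleList([TransformerEncoderBlock(d_model=64, n_heads=4) for _ in range(2)])
    x = torch.randn(2, 10, 64)
    pad_mask = torch.zeros(2, 10, dtype=torch.bool)
    pad_mask[:, 8:] = True  # last two positions are padding
    for block in blocks:
        x = block(x, pad_mask)
    return x.shape  # torch.Size([2, 10, 64])

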
class TransformerDecoderBlock(nn.Module):
    """
    Transformer Decoder Block

    Structure:
        x → LayerNorm → MaskedSelfAttention → Add(x) →
          → LayerNorm → CrossAttention(encoder_output) → Add(x) →
          → LayerNorm → FeedForward → Add(x) → output
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()

        self.self_attention = MultiHeadAttentionLowLevel(d_model, n_heads, dropout)
        self.cross_attention = MultiHeadAttentionLowLevel(d_model, n_heads, dropout)
        self.feed_forward = FeedForwardLowLevel(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        self_mask: torch.Tensor = None,
        cross_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            x: decoder input (batch, tgt_len, d_model)
            encoder_output: encoder output (batch, src_len, d_model)
            self_mask: causal mask for self-attention
            cross_mask: padding mask for cross-attention

        Returns:
            output: (batch, tgt_len, d_model)
        """
        # Masked Self-Attention
        normed = self.norm1(x)
        attn_out = self.self_attention(normed, normed, normed, self_mask)
        x = x + self.dropout(attn_out)

        # Cross-Attention (query: decoder, key/value: encoder)
        normed = self.norm2(x)
        cross_out = self.cross_attention(normed, encoder_output, encoder_output, cross_mask)
        x = x + self.dropout(cross_out)

        # Feed-Forward
        normed = self.norm3(x)
        ff_out = self.feed_forward(normed)
        x = x + self.dropout(ff_out)

        return x


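# Usage sketch (illustrative values): a decoder block takes the target-side
# activations plus the encoder output, and a causal mask so each target
# position only attends to itself and earlier positions. Shape (1, tgt_len,
# tgt_len) lets the attention module broadcast the mask over the batch.
def _demo_decoder_block() -> torch.Size:
    block = TransformerDecoderBlock(d_model=64, n_heads=4)
    tgt = torch.randn(2, 8, 64)      # decoder input
    memory = torch.randn(2, 10, 64)  # encoder output
    causal = torch.triu(torch.ones(8, 8, dtype=torch.bool), diagonal=1)
    return block(tgt, memory, causal.unsqueeze(0)).shape  # torch.Size([2, 8, 64])

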
class TransformerLowLevel(nn.Module):
    """
    Full Transformer model (Encoder-Decoder)

    For seq2seq tasks such as translation and summarization.
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        n_encoder_layers: int = 6,
        n_decoder_layers: int = 6,
        d_ff: int = 2048,
        max_len: int = 5000,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        # Positional Encoding
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)

        # Encoder
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_encoder_layers)
        ])
        self.encoder_norm = nn.LayerNorm(d_model)

        # Decoder
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_decoder_layers)
        ])
        self.decoder_norm = nn.LayerNorm(d_model)

        # Output
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)

        # Scaling factor for embeddings
        self.scale = math.sqrt(d_model)

    def create_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Causal mask: prevents attending to future tokens.

        Returns:
            mask: (seq_len, seq_len) - True = masked
        """
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        return mask.bool()

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Encoder forward pass

        Args:
            src: source tokens (batch, src_len)
            src_mask: padding mask

        Returns:
            encoder_output: (batch, src_len, d_model)
        """
        # Embedding + Positional Encoding
        x = self.src_embedding(src) * self.scale
        x = self.pos_encoding(x)

        # Encoder layers
        for layer in self.encoder_layers:
            x = layer(x, src_mask)

        x = self.encoder_norm(x)
        return x

    def decode(
        self,
        tgt: torch.Tensor,
        encoder_output: torch.Tensor,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Decoder forward pass

        Args:
            tgt: target tokens (batch, tgt_len)
            encoder_output: (batch, src_len, d_model)
            tgt_mask: causal mask
            memory_mask: cross-attention mask

        Returns:
            decoder_output: (batch, tgt_len, d_model)
        """
        # Embedding + Positional Encoding
        x = self.tgt_embedding(tgt) * self.scale
        x = self.pos_encoding(x)

        # Causal mask
        if tgt_mask is None:
            # Add a leading batch dimension so the attention module treats it as a
            # (batch, seq, seq) mask and broadcasts it over the batch
            tgt_mask = self.create_causal_mask(tgt.size(1), tgt.device).unsqueeze(0)

        # Decoder layers
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, tgt_mask, memory_mask)

        x = self.decoder_norm(x)
        return x

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Full forward pass

        Args:
            src: source tokens (batch, src_len)
            tgt: target tokens (batch, tgt_len)

        Returns:
            logits: (batch, tgt_len, vocab_size)
        """
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, tgt_mask, src_mask)
        logits = self.output_projection(decoder_output)
        return logits


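# Inference sketch (hypothetical helper, not part of the model above): greedy
# decoding runs the encoder once, then repeatedly extends the target sequence
# with the most likely next token. The bos_id / eos_id values are assumptions
# for illustration and depend on the tokenizer used.
@torch.no_grad()
def greedy_decode(
    model: TransformerLowLevel,
    src: torch.Tensor,
    max_len: int = 20,
    bos_id: int = 1,
    eos_id: int = 2,
) -> torch.Tensor:
    model.eval()
    memory = model.encode(src)  # (batch, src_len, d_model), computed once
    ys = torch.full((src.size(0), 1), bos_id, dtype=torch.long, device=src.device)
    for _ in range(max_len - 1):
        logits = model.output_projection(model.decode(ys, memory))
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_token], dim=1)
        if (next_token == eos_id).all():
            break
    return ys

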
def main():
    """Test run"""
    print("=" * 60)
    print("Transformer - PyTorch Low-Level Implementation")
    print("=" * 60)

    # Settings
    src_vocab_size = 10000
    tgt_vocab_size = 10000
    d_model = 256
    n_heads = 8
    n_layers = 4
    batch_size = 2
    src_len = 10
    tgt_len = 8

    # Create the model
    model = TransformerLowLevel(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        n_encoder_layers=n_layers,
        n_decoder_layers=n_layers,
    )

    # Parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nNumber of model parameters: {total_params:,}")

    # Dummy data
    src = torch.randint(0, src_vocab_size, (batch_size, src_len))
    tgt = torch.randint(0, tgt_vocab_size, (batch_size, tgt_len))

    print("\nInput shapes:")
    print(f"  Source: {src.shape}")
    print(f"  Target: {tgt.shape}")

    # Forward pass
    model.eval()
    with torch.no_grad():
        logits = model(src, tgt)

    print(f"\nOutput shape: {logits.shape}")
    print(f"  Expected: (batch={batch_size}, tgt_len={tgt_len}, vocab={tgt_vocab_size})")

    # Attention pattern visualization (optional)
    print("\nTest complete!")


if __name__ == "__main__":
    main()