transformer_lowlevel.py

  1"""
  2Transformer - PyTorch Low-Level ๊ตฌํ˜„
  3
  4์ด ํŒŒ์ผ์€ Transformer๋ฅผ PyTorch ๊ธฐ๋ณธ ์—ฐ์‚ฐ๋งŒ์œผ๋กœ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค.
  5nn.TransformerEncoder, nn.MultiheadAttention ๋“ฑ ๊ณ ์ˆ˜์ค€ API๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ์•Š๊ณ 
  6์ง์ ‘ attention๊ณผ FFN์„ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค.
  7
  8๋…ผ๋ฌธ: "Attention Is All You Need" (Vaswani et al., 2017)
  9
 10ํ•™์Šต ๋ชฉํ‘œ:
 111. Scaled Dot-Product Attention ๊ตฌํ˜„
 122. Multi-Head Attention ๊ตฌํ˜„
 133. Positional Encoding ๊ตฌํ˜„
 144. Encoder/Decoder ๋ธ”๋ก ๊ตฌํ˜„
 15"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor = None,
    dropout: nn.Dropout = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Scaled Dot-Product Attention

    Attention(Q, K, V) = softmax(QK^T / √d_k) V

    Args:
        query: (batch, n_heads, seq_len, d_k)
        key: (batch, n_heads, seq_len, d_k)
        value: (batch, n_heads, seq_len, d_v)
        mask: (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len), True = masked
        dropout: Dropout layer

    Returns:
        output: (batch, n_heads, seq_len, d_v)
        attention_weights: (batch, n_heads, seq_len, seq_len)
    """
    d_k = query.size(-1)

    # 1. QK^T: similarity between queries and keys
    # (batch, heads, seq, d_k) @ (batch, heads, d_k, seq) → (batch, heads, seq, seq)
    scores = torch.matmul(query, key.transpose(-2, -1))

    # 2. Scaling: divide by √d_k (keeps the softmax in a stable range)
    scores = scores / math.sqrt(d_k)

    # 3. Masking (optional)
    if mask is not None:
        # Positions where mask is True are set to -inf (become 0 after softmax)
        scores = scores.masked_fill(mask, float('-inf'))

    # 4. Softmax: convert scores to a probability distribution
    attention_weights = F.softmax(scores, dim=-1)

    # 5. Dropout (during training)
    if dropout is not None:
        attention_weights = dropout(attention_weights)

    # 6. Weighted sum of the values
    output = torch.matmul(attention_weights, value)

    return output, attention_weights


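# Optional sanity-check sketch (not part of the original tests): PyTorch 2.0+ ships
# F.scaled_dot_product_attention, so the hand-written version above can be compared
# against it on random tensors. Note the built-in uses the opposite mask convention
# (True = "may attend"), hence the `~causal` below.
def _check_sdpa_against_builtin():
    q = torch.randn(2, 4, 6, 8)   # (batch, n_heads, seq_len, d_k)
    k = torch.randn(2, 4, 6, 8)
    v = torch.randn(2, 4, 6, 8)
    causal = torch.triu(torch.ones(6, 6, dtype=torch.bool), diagonal=1)

    ours, _ = scaled_dot_product_attention(q, k, v, mask=causal)
    builtin = F.scaled_dot_product_attention(q, k, v, attn_mask=~causal)
    assert torch.allclose(ours, builtin, atol=1e-5)

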
class MultiHeadAttentionLowLevel(nn.Module):
    """
    Multi-Head Attention (low-level implementation)

    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O

    Implemented directly, without nn.MultiheadAttention.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        """
        Args:
            d_model: model dimension
            n_heads: number of attention heads
            dropout: dropout rate
        """
        super().__init__()

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension per head

        # Q, K, V projections (kept as three separate layers here; they could be
        # fused into a single matmul, or managed as raw parameters instead of nn.Linear)
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        # Output projection
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            query: (batch, seq_len, d_model)
            key: (batch, seq_len, d_model) - for cross-attention the key/value
                 sequence length may differ from the query's
            value: (batch, seq_len, d_model)
            mask: (batch, seq_len) padding mask, or
                  (batch, seq_len, seq_len) / (1, seq_len, seq_len) attention mask

        Returns:
            output: (batch, seq_len, d_model)
        """
        batch_size = query.size(0)

        # 1. Linear projections
        # (batch, seq, d_model) → (batch, seq, d_model)
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # 2. Split into multiple heads
        # (batch, seq, d_model) → (batch, seq, n_heads, d_k) → (batch, n_heads, seq, d_k)
        Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # 3. Adjust mask dimensions (for broadcasting)
        if mask is not None:
            if mask.dim() == 2:
                # (batch, seq) → (batch, 1, 1, seq)
                mask = mask.unsqueeze(1).unsqueeze(2)
            elif mask.dim() == 3:
                # (batch, seq, seq) → (batch, 1, seq, seq)
                mask = mask.unsqueeze(1)

        # 4. Attention
        attn_output, _ = scaled_dot_product_attention(Q, K, V, mask, self.dropout)

        # 5. Concat heads
        # (batch, n_heads, seq, d_k) → (batch, seq, n_heads, d_k) → (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)

        # 6. Output projection
        output = self.W_o(attn_output)

        return output


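# Usage sketch: trace the shapes through one multi-head attention call with a
# padding mask (all values here are made up for illustration).
def _demo_multi_head_shapes():
    mha = MultiHeadAttentionLowLevel(d_model=64, n_heads=4, dropout=0.0)
    x = torch.randn(2, 10, 64)                # (batch, seq_len, d_model)
    pad_mask = torch.zeros(2, 10, dtype=torch.bool)
    pad_mask[:, 8:] = True                    # pretend the last two positions are padding
    out = mha(x, x, x, pad_mask)              # self-attention: Q = K = V = x
    print(out.shape)                          # torch.Size([2, 10, 64])

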
class PositionalEncoding(nn.Module):
    """
    Sinusoidal Positional Encoding

    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Precompute the positional encodings
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # 1 / 10000^(2i/d_model) = exp(-2i * log(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices

        pe = pe.unsqueeze(0)  # (1, max_len, d_model)

        # Register as a non-learnable buffer
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)

        Returns:
            x + PE: (batch, seq_len, d_model)
        """
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)


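# Quick check (a sketch): the exp/log trick in __init__ is just a numerically
# convenient way of computing 1 / 10000^(2i/d_model) directly.
def _check_pe_frequencies(d_model: int = 16):
    two_i = torch.arange(0, d_model, 2).float()
    direct = 1.0 / torch.pow(torch.tensor(10000.0), two_i / d_model)
    via_exp = torch.exp(two_i * (-math.log(10000.0) / d_model))
    assert torch.allclose(direct, via_exp)

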
class FeedForwardLowLevel(nn.Module):
    """
    Position-wise Feed-Forward Network

    FFN(x) = GELU(xW_1 + b_1)W_2 + b_2

    Typically d_ff = 4 * d_model (4x expansion)
    """

    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()

        if d_ff is None:
            d_ff = 4 * d_model

        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)

        Returns:
            output: (batch, seq_len, d_model)
        """
        # GELU activation (the original paper uses ReLU; GELU is the modern default)
        x = self.linear1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


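# Sketch: with the default 4x expansion the FFN carries the bulk of a block's weights.
# For d_model = 256: 256*1024 + 1024 (linear1) + 1024*256 + 256 (linear2) = 525,568.
def _demo_ffn_params(d_model: int = 256):
    ffn = FeedForwardLowLevel(d_model)        # d_ff defaults to 4 * d_model = 1024
    n_params = sum(p.numel() for p in ffn.parameters())
    print(n_params)                           # 525568

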
class TransformerEncoderBlock(nn.Module):
    """
    Transformer Encoder Block

    Structure:
    x → LayerNorm → MultiHeadAttention → Dropout → Add(x) →
      → LayerNorm → FeedForward → Dropout → Add(x) → output
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()

        self.attention = MultiHeadAttentionLowLevel(d_model, n_heads, dropout)
        self.feed_forward = FeedForwardLowLevel(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: padding mask

        Returns:
            output: (batch, seq_len, d_model)
        """
        # Pre-norm (modern convention; the original paper uses post-norm)
        # Self-attention + residual
        normed = self.norm1(x)
        attn_out = self.attention(normed, normed, normed, mask)
        x = x + self.dropout(attn_out)

        # Feed-forward + residual
        normed = self.norm2(x)
        ff_out = self.feed_forward(normed)
        x = x + self.dropout(ff_out)

        return x


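# Usage sketch: an encoder block maps (batch, seq, d_model) to the same shape, so
# blocks can be stacked freely; the same padding mask is reused at every layer.
def _demo_encoder_block_stack():
    blocks = nn.ModuleList(
        [TransformerEncoderBlock(d_model=64, n_heads=4) for _ in range(2)]
    )
    x = torch.randn(2, 10, 64)
    pad_mask = torch.zeros(2, 10, dtype=torch.bool)
    for block in blocks:
        x = block(x, pad_mask)
    print(x.shape)                            # torch.Size([2, 10, 64])

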
class TransformerDecoderBlock(nn.Module):
    """
    Transformer Decoder Block

    Structure:
    x → LayerNorm → MaskedSelfAttention → Dropout → Add(x) →
      → LayerNorm → CrossAttention(encoder_output) → Dropout → Add(x) →
      → LayerNorm → FeedForward → Dropout → Add(x) → output
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()

        self.self_attention = MultiHeadAttentionLowLevel(d_model, n_heads, dropout)
        self.cross_attention = MultiHeadAttentionLowLevel(d_model, n_heads, dropout)
        self.feed_forward = FeedForwardLowLevel(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        self_mask: torch.Tensor = None,
        cross_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            x: decoder input (batch, tgt_len, d_model)
            encoder_output: encoder output (batch, src_len, d_model)
            self_mask: causal mask for self-attention
            cross_mask: padding mask for cross-attention

        Returns:
            output: (batch, tgt_len, d_model)
        """
        # Masked self-attention
        normed = self.norm1(x)
        attn_out = self.self_attention(normed, normed, normed, self_mask)
        x = x + self.dropout(attn_out)

        # Cross-attention (query: decoder, key/value: encoder)
        normed = self.norm2(x)
        cross_out = self.cross_attention(normed, encoder_output, encoder_output, cross_mask)
        x = x + self.dropout(cross_out)

        # Feed-forward
        normed = self.norm3(x)
        ff_out = self.feed_forward(normed)
        x = x + self.dropout(ff_out)

        return x


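# Usage sketch: one decoder block with a causal self-attention mask built inline.
# The mask has shape (1, tgt_len, tgt_len) so it broadcasts over the batch inside
# the attention; the "encoder output" here is just random noise for illustration.
def _demo_decoder_block():
    block = TransformerDecoderBlock(d_model=64, n_heads=4)
    tgt = torch.randn(2, 6, 64)               # (batch, tgt_len, d_model)
    memory = torch.randn(2, 9, 64)            # (batch, src_len, d_model)
    causal = torch.triu(torch.ones(6, 6, dtype=torch.bool), diagonal=1).unsqueeze(0)
    out = block(tgt, memory, self_mask=causal)
    print(out.shape)                          # torch.Size([2, 6, 64])

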
class TransformerLowLevel(nn.Module):
    """
    Full Transformer model (encoder-decoder)

    For seq2seq tasks such as translation and summarization.
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        n_encoder_layers: int = 6,
        n_decoder_layers: int = 6,
        d_ff: int = 2048,
        max_len: int = 5000,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        # Positional Encoding
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)

        # Encoder
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_encoder_layers)
        ])
        self.encoder_norm = nn.LayerNorm(d_model)

        # Decoder
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_decoder_layers)
        ])
        self.decoder_norm = nn.LayerNorm(d_model)

        # Output
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)

        # Scaling factor for embeddings
        self.scale = math.sqrt(d_model)

    def create_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Causal mask: prevents positions from attending to future tokens.

        Returns:
            mask: (1, seq_len, seq_len) - True = masked; the leading dim of 1
                  broadcasts over the batch inside the attention
        """
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        # Add a leading batch dim so the mask broadcasts over (batch, n_heads, ...)
        return mask.bool().unsqueeze(0)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Encoder forward pass

        Args:
            src: source tokens (batch, src_len)
            src_mask: padding mask

        Returns:
            encoder_output: (batch, src_len, d_model)
        """
        # Embedding + positional encoding
        x = self.src_embedding(src) * self.scale
        x = self.pos_encoding(x)

        # Encoder layers
        for layer in self.encoder_layers:
            x = layer(x, src_mask)

        x = self.encoder_norm(x)
        return x

    def decode(
        self,
        tgt: torch.Tensor,
        encoder_output: torch.Tensor,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Decoder forward pass

        Args:
            tgt: target tokens (batch, tgt_len)
            encoder_output: (batch, src_len, d_model)
            tgt_mask: causal mask
            memory_mask: cross-attention mask

        Returns:
            decoder_output: (batch, tgt_len, d_model)
        """
        # Embedding + positional encoding
        x = self.tgt_embedding(tgt) * self.scale
        x = self.pos_encoding(x)

        # Causal mask
        if tgt_mask is None:
            tgt_mask = self.create_causal_mask(tgt.size(1), tgt.device)

        # Decoder layers
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, tgt_mask, memory_mask)

        x = self.decoder_norm(x)
        return x

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Full forward pass

        Args:
            src: source tokens (batch, src_len)
            tgt: target tokens (batch, tgt_len)
            src_mask: source padding mask (also used for cross-attention)
            tgt_mask: causal mask for the decoder

        Returns:
            logits: (batch, tgt_len, vocab_size)
        """
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, tgt_mask, src_mask)
        logits = self.output_projection(decoder_output)
        return logits


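# Inference sketch (not part of the original file): greedy decoding with the model
# above. The bos_id/eos_id values are hypothetical placeholders; substitute your
# tokenizer's actual special-token ids.
@torch.no_grad()
def greedy_decode(
    model: TransformerLowLevel,
    src: torch.Tensor,
    max_len: int = 20,
    bos_id: int = 1,
    eos_id: int = 2,
) -> torch.Tensor:
    model.eval()
    memory = model.encode(src)                                       # encode once
    tgt = torch.full((src.size(0), 1), bos_id, dtype=torch.long, device=src.device)
    for _ in range(max_len - 1):
        logits = model.output_projection(model.decode(tgt, memory))
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # most likely next id
        tgt = torch.cat([tgt, next_token], dim=1)
        if (next_token == eos_id).all():
            break
    return tgt

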
def main():
    """Smoke test"""
    print("=" * 60)
    print("Transformer - PyTorch Low-Level Implementation")
    print("=" * 60)

    # Configuration
    src_vocab_size = 10000
    tgt_vocab_size = 10000
    d_model = 256
    n_heads = 8
    n_layers = 4
    batch_size = 2
    src_len = 10
    tgt_len = 8

    # Build the model
    model = TransformerLowLevel(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        n_encoder_layers=n_layers,
        n_decoder_layers=n_layers,
    )

    # Parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {total_params:,}")

    # Dummy data
    src = torch.randint(0, src_vocab_size, (batch_size, src_len))
    tgt = torch.randint(0, tgt_vocab_size, (batch_size, tgt_len))

    print("\nInput shapes:")
    print(f"  Source: {src.shape}")
    print(f"  Target: {tgt.shape}")

    # Forward pass
    model.eval()
    with torch.no_grad():
        logits = model(src, tgt)

    print(f"\nOutput shape: {logits.shape}")
    print(f"  Expected: (batch={batch_size}, tgt_len={tgt_len}, vocab={tgt_vocab_size})")

    # Attention pattern visualization could be added here (optional)
    print("\nTest complete!")


if __name__ == "__main__":
    main()