03. Transformer Review

ํ•™์Šต ๋ชฉํ‘œ

  • NLP ๊ด€์ ์—์„œ Transformer ์ดํ•ด
  • Encoder์™€ Decoder ๊ตฌ์กฐ
  • ์–ธ์–ด ๋ชจ๋ธ๋ง ๊ด€์ ์˜ Attention
  • BERT/GPT ๊ธฐ๋ฐ˜ ๊ตฌ์กฐ ์ดํ•ด

1. Transformer ๊ฐœ์š”

๊ตฌ์กฐ ์š”์•ฝ

์ธ์ฝ”๋” (BERT ์Šคํƒ€์ผ):
    ์ž…๋ ฅ โ†’ [Embedding + Positional] โ†’ [Self-Attention + FFN] ร— N โ†’ ์ถœ๋ ฅ

๋””์ฝ”๋” (GPT ์Šคํƒ€์ผ):
    ์ž…๋ ฅ โ†’ [Embedding + Positional] โ†’ [Masked Self-Attention + FFN] ร— N โ†’ ์ถœ๋ ฅ

์ธ์ฝ”๋”-๋””์ฝ”๋” (T5 ์Šคํƒ€์ผ):
    ์ž…๋ ฅ โ†’ ์ธ์ฝ”๋” โ†’ [Cross-Attention] โ†’ ๋””์ฝ”๋” โ†’ ์ถœ๋ ฅ

NLP์—์„œ์˜ ์—ญํ• 

๋ชจ๋ธ ๊ตฌ์กฐ ์šฉ๋„
BERT ์ธ์ฝ”๋” only ๋ถ„๋ฅ˜, QA, NER
GPT ๋””์ฝ”๋” only ํ…์ŠคํŠธ ์ƒ์„ฑ
T5, BART ์ธ์ฝ”๋”-๋””์ฝ”๋” ๋ฒˆ์—ญ, ์š”์•ฝ

2. Self-Attention (NLP ๊ด€์ )

๋ฌธ์žฅ ๋‚ด ๊ด€๊ณ„ ํ•™์Šต

"The cat sat on the mat because it was tired"

"it" โ†’ Attention โ†’ "cat" (๋†’์€ ๊ฐ€์ค‘์น˜)
                โ†’ "mat" (๋‚ฎ์€ ๊ฐ€์ค‘์น˜)

๋ชจ๋ธ์ด ๋Œ€๋ช…์‚ฌ "it"์ด "cat"์„ ์ง€์นญํ•จ์„ ํ•™์Šต

Query, Key, Value

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Q, K, V ๊ณ„์‚ฐ
        Q = self.W_q(x)  # (batch, seq, d_model)
        K = self.W_k(x)
        V = self.W_v(x)

        # Multi-head ๋ถ„ํ• 
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # (batch, num_heads, seq, d_k)

        # Scaled Dot-Product Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, V)

        # ํ—ค๋“œ ๊ฒฐํ•ฉ
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(context)

        return output, attention_weights
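
A quick shape check for the SelfAttention module above (the sizes are arbitrary):

# Quick shape check (arbitrary sizes)
attn = SelfAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)                 # (batch, seq, d_model)
out, weights = attn(x)
print(out.shape)      # torch.Size([2, 10, 64])
print(weights.shape)  # torch.Size([2, 8, 10, 10]) -> one seq x seq map per head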

3. Causal Masking (GPT ์Šคํƒ€์ผ)

์ž๊ธฐํšŒ๊ท€ ์–ธ์–ด ๋ชจ๋ธ

"I love NLP" ํ•™์Šต:
    ์ž…๋ ฅ: [I]         โ†’ ์˜ˆ์ธก: love
    ์ž…๋ ฅ: [I, love]   โ†’ ์˜ˆ์ธก: NLP
    ์ž…๋ ฅ: [I, love, NLP] โ†’ ์˜ˆ์ธก: <eos>

๋ฏธ๋ž˜ ํ† ํฐ์„ ๋ณด๋ฉด ์•ˆ ๋จ โ†’ Causal Mask ํ•„์š”

Causal Mask ๊ตฌํ˜„

def create_causal_mask(seq_len):
    """Create a lower-triangular mask (blocks future tokens)"""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask  # 1 = ์ฐธ์กฐ ๊ฐ€๋Šฅ, 0 = ๋งˆ์Šคํ‚น

# ์˜ˆ์‹œ (seq_len=4)
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads, max_len=512):
        super().__init__()
        self.attention = SelfAttention(d_model, num_heads)
        # ๋ฏธ๋ฆฌ ๊ณ„์‚ฐ๋œ ๋งˆ์Šคํฌ ๋“ฑ๋ก
        mask = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer('mask', mask)

    def forward(self, x):
        seq_len = x.size(1)
        mask = self.mask[:seq_len, :seq_len]
        return self.attention(x, mask)
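
As a sanity check, the attention weights returned by CausalSelfAttention should be zero above the diagonal, i.e. no position attends to a later one:

# Sanity check: every head assigns zero weight to future positions
causal_attn = CausalSelfAttention(d_model=64, num_heads=4, max_len=16)
x = torch.randn(1, 5, 64)
_, weights = causal_attn(x)                      # (1, 4, 5, 5)
print(torch.all(weights.triu(diagonal=1) == 0))  # tensor(True)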

4. Encoder vs Decoder

์ธ์ฝ”๋” (์–‘๋ฐฉํ–ฅ)

class TransformerEncoderBlock(nn.Module):
    """BERT-style encoder block"""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = SelfAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, padding_mask=None):
        # Self-Attention (์–‘๋ฐฉํ–ฅ)
        attn_out, _ = self.self_attn(x, padding_mask)
        x = self.norm1(x + self.dropout(attn_out))

        # Feed Forward
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))

        return x

๋””์ฝ”๋” (๋‹จ๋ฐฉํ–ฅ)

class TransformerDecoderBlock(nn.Module):
    """GPT-style decoder block"""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = CausalSelfAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Masked Self-Attention (๋‹จ๋ฐฉํ–ฅ)
        attn_out, _ = self.self_attn(x)
        x = self.norm1(x + self.dropout(attn_out))

        # Feed Forward
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))

        return x
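
Both block types map (batch, seq, d_model) to the same shape, which is why they can be stacked N times:

# Both block types preserve (batch, seq, d_model), so they can be stacked freely
enc_block = TransformerEncoderBlock(d_model=64, num_heads=4, d_ff=256)
dec_block = TransformerDecoderBlock(d_model=64, num_heads=4, d_ff=256)
x = torch.randn(2, 10, 64)
print(enc_block(x).shape)  # torch.Size([2, 10, 64])
print(dec_block(x).shape)  # torch.Size([2, 10, 64])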

5. Positional Encoding

Sinusoidal (์›๋ณธ Transformer)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                           (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)

        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

Learnable (BERT, GPT)

class LearnablePositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        return x + self.pos_embedding(positions)
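
Both variants are simply added to the token embeddings and preserve the input shape:

# Both encodings add position information and keep (batch, seq, d_model)
emb = torch.randn(2, 10, 64)                # (batch, seq, d_model)

sin_pe = PositionalEncoding(d_model=64)
learned_pe = LearnablePositionalEncoding(d_model=64)

print(sin_pe(emb).shape)      # torch.Size([2, 10, 64]); fixed, no trained parameters
print(learned_pe(emb).shape)  # torch.Size([2, 10, 64]); positions learned like any embedding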

6. Complete Transformer Model

GPT-์Šคํƒ€์ผ ์–ธ์–ด ๋ชจ๋ธ

class GPTModel(nn.Module):
    def __init__(self, vocab_size, d_model=768, num_heads=12,
                 num_layers=12, d_ff=3072, max_len=1024, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # ํ† ํฐ + ์œ„์น˜ ์ž„๋ฒ ๋”ฉ
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_len, d_model)

        # Decoder ๋ธ”๋ก
        self.blocks = nn.ModuleList([
            TransformerDecoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying (์„ ํƒ)
        self.head.weight = self.token_embedding.weight

    def forward(self, x):
        # x: (batch, seq_len)
        batch_size, seq_len = x.shape

        # ์ž„๋ฒ ๋”ฉ
        tok_emb = self.token_embedding(x)
        pos = torch.arange(seq_len, device=x.device)
        pos_emb = self.position_embedding(pos)
        x = tok_emb + pos_emb

        # Transformer ๋ธ”๋ก
        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.head(x)  # (batch, seq, vocab_size)

        return logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0):
        """Autoregressive text generation"""
        for _ in range(max_new_tokens):
            # Crop the context so it never exceeds the position-embedding table
            idx_cond = idx[:, -self.position_embedding.num_embeddings:]
            # Logits at the last position
            logits = self(idx_cond)[:, -1, :]  # (batch, vocab)
            probs = F.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_token], dim=1)
        return idx
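
For training, the targets are the input tokens shifted one position to the left. A minimal sketch of one loss computation and a generation call, using tiny illustrative hyperparameters:

# Next-token-prediction loss and sampling (tiny illustrative config)
model = GPTModel(vocab_size=100, d_model=64, num_heads=4,
                 num_layers=2, d_ff=256, max_len=32)
tokens = torch.randint(0, 100, (2, 16))          # (batch, seq)

inputs, targets = tokens[:, :-1], tokens[:, 1:]  # targets = inputs shifted left by one
logits = model(inputs)                           # (batch, seq-1, vocab_size)
loss = F.cross_entropy(logits.reshape(-1, 100), targets.reshape(-1))

generated = model.generate(tokens[:, :4], max_new_tokens=8)
print(loss.item(), generated.shape)              # scalar loss, torch.Size([2, 12])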

BERT-์Šคํƒ€์ผ ์ธ์ฝ”๋”

class BERTModel(nn.Module):
    def __init__(self, vocab_size, d_model=768, num_heads=12,
                 num_layers=12, d_ff=3072, max_len=512, dropout=0.1):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_len, d_model)
        self.segment_embedding = nn.Embedding(2, d_model)  # distinguishes sentence A / sentence B

        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)

    def forward(self, input_ids, segment_ids=None, attention_mask=None):
        batch_size, seq_len = input_ids.shape

        # ์ž„๋ฒ ๋”ฉ ๊ฒฐํ•ฉ
        tok_emb = self.token_embedding(input_ids)
        pos = torch.arange(seq_len, device=input_ids.device)
        pos_emb = self.position_embedding(pos)

        if segment_ids is None:
            segment_ids = torch.zeros_like(input_ids)
        seg_emb = self.segment_embedding(segment_ids)

        x = tok_emb + pos_emb + seg_emb

        # Transformer ๋ธ”๋ก
        for block in self.blocks:
            x = block(x, attention_mask)

        return self.ln_f(x)
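
Usage sketch with a tiny illustrative config: the encoder returns one contextual vector per token, and the padding mask (1 = real token, 0 = padding) keeps attention away from pad positions:

# Contextual token vectors from the BERT-style encoder (tiny illustrative config)
bert = BERTModel(vocab_size=100, d_model=64, num_heads=4,
                 num_layers=2, d_ff=256, max_len=32)
input_ids = torch.randint(0, 100, (2, 10))
attention_mask = torch.ones(2, 10)
attention_mask[:, 8:] = 0                 # last two positions are padding

hidden = bert(input_ids, attention_mask=attention_mask)
print(hidden.shape)                       # torch.Size([2, 10, 64]) -> one vector per token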

7. ํ•™์Šต ๋ชฉํ‘œ๋ณ„ ๋น„๊ต

Masked Language Modeling (BERT)

์ž…๋ ฅ: "The [MASK] sat on the mat"
์˜ˆ์ธก: [MASK] โ†’ "cat"

15% ํ† ํฐ์„ ๋งˆ์Šคํ‚นํ•˜์—ฌ ์˜ˆ์ธก
์–‘๋ฐฉํ–ฅ ๋ฌธ๋งฅ ํ™œ์šฉ

Causal Language Modeling (GPT)

์ž…๋ ฅ: "The cat sat on"
์˜ˆ์ธก: "the" "cat" "sat" "on" "the" "mat"

๋‹ค์Œ ํ† ํฐ ์˜ˆ์ธก
๋‹จ๋ฐฉํ–ฅ (์™ผ์ชฝโ†’์˜ค๋ฅธ์ชฝ)

Seq2Seq (T5, BART)

์ž…๋ ฅ: "translate English to French: Hello"
์ถœ๋ ฅ: "Bonjour"

์ธ์ฝ”๋”: ์ž…๋ ฅ ์ดํ•ด
๋””์ฝ”๋”: ์ถœ๋ ฅ ์ƒ์„ฑ

8. PyTorch ๋‚ด์žฅ Transformer

import torch.nn as nn

# ์ธ์ฝ”๋”
encoder_layer = nn.TransformerEncoderLayer(
    d_model=512,
    nhead=8,
    dim_feedforward=2048,
    dropout=0.1,
    batch_first=True
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

# ๋””์ฝ”๋”
decoder_layer = nn.TransformerDecoderLayer(
    d_model=512,
    nhead=8,
    dim_feedforward=2048,
    dropout=0.1,
    batch_first=True
)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)

# ์‚ฌ์šฉ
x = torch.randn(32, 100, 512)  # (batch, seq, d_model)
encoded = encoder(x)
decoded = decoder(x, encoded)

์ •๋ฆฌ

๋ชจ๋ธ ๋น„๊ต

ํ•ญ๋ชฉ BERT (์ธ์ฝ”๋”) GPT (๋””์ฝ”๋”) T5 (Enc-Dec)
Attention ์–‘๋ฐฉํ–ฅ ๋‹จ๋ฐฉํ–ฅ (Causal) ์–‘๋ฐฉํ–ฅ + ๋‹จ๋ฐฉํ–ฅ
ํ•™์Šต MLM + NSP ๋‹ค์Œ ํ† ํฐ ์˜ˆ์ธก Denoising
์ถœ๋ ฅ ๋ฌธ๋งฅ ๋ฒกํ„ฐ ์ƒ์„ฑ ์ƒ์„ฑ
์šฉ๋„ ๋ถ„๋ฅ˜, QA ์ƒ์„ฑ, ๋Œ€ํ™” ๋ฒˆ์—ญ, ์š”์•ฝ

ํ•ต์‹ฌ ์ฝ”๋“œ

# Causal Mask
mask = torch.tril(torch.ones(seq_len, seq_len))
scores = scores.masked_fill(mask == 0, -1e9)

# Multi-Head Attention ๋ถ„ํ• 
Q = Q.view(batch, seq, num_heads, d_k).transpose(1, 2)

# Scaled Dot-Product
scores = Q @ K.T / sqrt(d_k)
attn = softmax(scores) @ V

๋‹ค์Œ ๋‹จ๊ณ„

04_BERT_Understanding.md covers BERT's architecture and training methods in detail.
