# 03_transformer_nlp.py
  1"""
  203. Transformer NLP 예제
  3
  4Transformer μ•„ν‚€ν…μ²˜ 볡슡 및 NLP 적용
  5"""
  6
  7print("=" * 60)
  8print("Transformer NLP")
  9print("=" * 60)
 10
 11try:
 12    import torch
 13    import torch.nn as nn
 14    import torch.nn.functional as F
 15    import math
 16
 17    # ============================================
 18    # 1. Self-Attention
 19    # ============================================
 20    print("\n[1] Self-Attention")
 21    print("-" * 40)
 22
 23    def scaled_dot_product_attention(Q, K, V, mask=None):
 24        """Scaled Dot-Product Attention"""
 25        d_k = K.size(-1)
 26        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
 27
 28        if mask is not None:
 29            scores = scores.masked_fill(mask == 0, -1e9)
 30
 31        attention_weights = F.softmax(scores, dim=-1)
 32        output = torch.matmul(attention_weights, V)
 33        return output, attention_weights
 34
 35    # ν…ŒμŠ€νŠΈ
 36    batch, seq_len, d_model = 2, 5, 64
 37    Q = torch.randn(batch, seq_len, d_model)
 38    K = torch.randn(batch, seq_len, d_model)
 39    V = torch.randn(batch, seq_len, d_model)
 40
 41    output, weights = scaled_dot_product_attention(Q, K, V)
 42    print(f"μž…λ ₯ shape: Q={Q.shape}, K={K.shape}, V={V.shape}")
 43    print(f"좜λ ₯ shape: {output.shape}")
 44    print(f"Attention weights shape: {weights.shape}")
 45
 46
 47    # ============================================
 48    # 2. Multi-Head Attention
 49    # ============================================
 50    print("\n[2] Multi-Head Attention")
 51    print("-" * 40)
 52
 53    class MultiHeadAttention(nn.Module):
 54        def __init__(self, d_model, num_heads):
 55            super().__init__()
 56            self.d_model = d_model
 57            self.num_heads = num_heads
 58            self.d_k = d_model // num_heads
 59
 60            self.W_q = nn.Linear(d_model, d_model)
 61            self.W_k = nn.Linear(d_model, d_model)
 62            self.W_v = nn.Linear(d_model, d_model)
 63            self.W_o = nn.Linear(d_model, d_model)
 64
 65        def forward(self, x, mask=None):
 66            batch_size, seq_len, _ = x.shape
 67
 68            Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
 69            K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
 70            V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
 71
 72            attn_output, _ = scaled_dot_product_attention(Q, K, V, mask)
 73            attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
 74
 75            return self.W_o(attn_output)
 76
 77    mha = MultiHeadAttention(d_model=64, num_heads=8)
 78    x = torch.randn(2, 10, 64)
 79    output = mha(x)
 80    print(f"Multi-Head Attention: {x.shape} β†’ {output.shape}")
 81
 82
 83    # ============================================
 84    # 3. Causal Mask (GPT μŠ€νƒ€μΌ)
 85    # ============================================
 86    print("\n[3] Causal Mask")
 87    print("-" * 40)
 88
 89    def create_causal_mask(seq_len):
 90        """미래 토큰 λ§ˆμŠ€ν‚Ή"""
 91        mask = torch.tril(torch.ones(seq_len, seq_len))
 92        return mask
 93
 94    mask = create_causal_mask(5)
 95    print("Causal Mask (5x5):")
 96    print(mask)
 97
 98
 99    # ============================================
100    # 4. Positional Encoding
101    # ============================================
102    print("\n[4] Positional Encoding")
103    print("-" * 40)
104
105    class PositionalEncoding(nn.Module):
106        def __init__(self, d_model, max_len=5000):
107            super().__init__()
108            pe = torch.zeros(max_len, d_model)
109            position = torch.arange(0, max_len).unsqueeze(1).float()
110            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
111
112            pe[:, 0::2] = torch.sin(position * div_term)
113            pe[:, 1::2] = torch.cos(position * div_term)
114            self.register_buffer('pe', pe.unsqueeze(0))
115
116        def forward(self, x):
117            return x + self.pe[:, :x.size(1)]
118
119    pe = PositionalEncoding(d_model=64)
120    x = torch.randn(2, 10, 64)
121    output = pe(x)
122    print(f"Positional Encoding: {x.shape} β†’ {output.shape}")
123
124
125    # ============================================
126    # 5. Transformer Encoder Block
127    # ============================================
128    print("\n[5] Transformer Encoder Block")
129    print("-" * 40)
130
131    class TransformerBlock(nn.Module):
132        def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
133            super().__init__()
134            self.attention = MultiHeadAttention(d_model, num_heads)
135            self.norm1 = nn.LayerNorm(d_model)
136            self.norm2 = nn.LayerNorm(d_model)
137            self.ffn = nn.Sequential(
138                nn.Linear(d_model, d_ff),
139                nn.GELU(),
140                nn.Linear(d_ff, d_model)
141            )
142            self.dropout = nn.Dropout(dropout)
143
144        def forward(self, x, mask=None):
145            # Self-Attention + Residual
146            attn_out = self.attention(x, mask)
147            x = self.norm1(x + self.dropout(attn_out))
148
149            # FFN + Residual
150            ffn_out = self.ffn(x)
151            x = self.norm2(x + self.dropout(ffn_out))
152
153            return x
154
155    block = TransformerBlock(d_model=64, num_heads=8, d_ff=256)
156    x = torch.randn(2, 10, 64)
157    output = block(x)
158    print(f"Transformer Block: {x.shape} β†’ {output.shape}")
159
160
161    # ============================================
162    # 6. ν…μŠ€νŠΈ λΆ„λ₯˜ Transformer
163    # ============================================
164    print("\n[6] ν…μŠ€νŠΈ λΆ„λ₯˜ Transformer")
165    print("-" * 40)
166
167    class TransformerClassifier(nn.Module):
168        def __init__(self, vocab_size, d_model, num_heads, num_layers, num_classes):
169            super().__init__()
170            self.embedding = nn.Embedding(vocab_size, d_model)
171            self.pos_encoding = PositionalEncoding(d_model)
172            self.blocks = nn.ModuleList([
173                TransformerBlock(d_model, num_heads, d_model * 4)
174                for _ in range(num_layers)
175            ])
176            self.fc = nn.Linear(d_model, num_classes)
177
178        def forward(self, x):
179            x = self.embedding(x)
180            x = self.pos_encoding(x)
181            for block in self.blocks:
182                x = block(x)
183            x = x.mean(dim=1)  # 평균 풀링
184            return self.fc(x)
185
186    model = TransformerClassifier(
187        vocab_size=10000, d_model=128, num_heads=4, num_layers=2, num_classes=2
188    )
189    x = torch.randint(0, 10000, (4, 32))  # (batch, seq)
190    output = model(x)
191    print(f"μž…λ ₯: {x.shape}")
192    print(f"좜λ ₯: {output.shape}")
193    print(f"νŒŒλΌλ―Έν„° 수: {sum(p.numel() for p in model.parameters()):,}")
194
195
196    # ============================================
197    # 7. PyTorch λ‚΄μž₯ Transformer
198    # ============================================
199    print("\n[7] PyTorch λ‚΄μž₯ Transformer")
200    print("-" * 40)
201
202    encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
203    encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
204
205    x = torch.randn(32, 100, 512)  # (batch, seq, d_model)
206    output = encoder(x)
207    print(f"PyTorch Transformer: {x.shape} β†’ {output.shape}")
208
209
210    # ============================================
211    # 정리
212    # ============================================
213    print("\n" + "=" * 60)
214    print("Transformer 정리")
215    print("=" * 60)
216
217    summary = """
218핡심 κ΅¬μ„±μš”μ†Œ:
219    1. Self-Attention: Q @ K.T / sqrt(d_k) β†’ softmax β†’ @ V
220    2. Multi-Head: μ—¬λŸ¬ ν—€λ“œλ‘œ λΆ„ν•  ν›„ κ²°ν•©
221    3. Positional Encoding: μœ„μΉ˜ 정보 μΆ”κ°€
222    4. FFN: Linear β†’ GELU β†’ Linear
223    5. Residual + LayerNorm
224
225BERT vs GPT:
226    - BERT: μ–‘λ°©ν–₯ (인코더), 마슀크 μ—†μŒ
227    - GPT: 단방ν–₯ (디코더), Causal Mask
228"""
229    print(summary)
230
231except ImportError as e:
232    print(f"PyTorch λ―Έμ„€μΉ˜: {e}")