# gpt_lowlevel.py — nanoGPT-style low-level GPT-2 implementation
"""
PyTorch low-level GPT-2 implementation.

A concise, nanoGPT-style implementation featuring
pre-LayerNorm blocks, causal self-attention, and weight tying.
"""
  7
  8import torch
  9import torch.nn as nn
 10import torch.nn.functional as F
 11import math
 12from typing import Optional, Tuple
 13from dataclasses import dataclass
 14
 15
@dataclass
class GPTConfig:
    """Hyperparameters for a GPT-2 style model (defaults match GPT-2 Small)."""
    vocab_size: int = 50257  # GPT-2 BPE vocabulary size
    block_size: int = 1024   # max sequence length (context window)
    n_layer: int = 12        # number of transformer blocks
    n_head: int = 12         # attention heads per block
    n_embd: int = 768        # embedding / hidden dimension
    dropout: float = 0.1     # dropout probability used throughout
    bias: bool = True        # include bias terms in Linear layers
 26
 27
 28class CausalSelfAttention(nn.Module):
 29    """Causal Self-Attention (Masked Multi-Head Attention)"""
 30
 31    def __init__(self, config: GPTConfig):
 32        super().__init__()
 33        assert config.n_embd % config.n_head == 0
 34
 35        self.n_head = config.n_head
 36        self.n_embd = config.n_embd
 37        self.head_dim = config.n_embd // config.n_head
 38        self.dropout = config.dropout
 39
 40        # Q, K, V๋ฅผ ํ•˜๋‚˜์˜ projection์œผ๋กœ (ํšจ์œจ์„ฑ)
 41        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
 42        # Output projection
 43        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
 44
 45        # Dropout
 46        self.attn_dropout = nn.Dropout(config.dropout)
 47        self.resid_dropout = nn.Dropout(config.dropout)
 48
 49        # Causal mask (๋ฏธ๋ž˜ ํ† ํฐ ์ฐธ์กฐ ์ฐจ๋‹จ)
 50        # register_buffer: ํ•™์Šตํ•˜์ง€ ์•Š์ง€๋งŒ state_dict์— ํฌํ•จ
 51        self.register_buffer(
 52            "bias",
 53            torch.tril(torch.ones(config.block_size, config.block_size))
 54                .view(1, 1, config.block_size, config.block_size)
 55        )
 56
 57    def forward(
 58        self,
 59        x: torch.Tensor,
 60        use_cache: bool = False,
 61        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
 62    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
 63        """
 64        Args:
 65            x: (batch, seq_len, n_embd)
 66            use_cache: KV cache ์‚ฌ์šฉ ์—ฌ๋ถ€ (์ƒ์„ฑ ์‹œ)
 67            past_kv: ์ด์ „ K, V ์บ์‹œ
 68
 69        Returns:
 70            y: (batch, seq_len, n_embd)
 71            present_kv: ํ˜„์žฌ K, V (์บ์‹ฑ์šฉ)
 72        """
 73        B, T, C = x.shape
 74
 75        # Q, K, V ๊ณ„์‚ฐ (ํ•˜๋‚˜์˜ matmul๋กœ)
 76        qkv = self.c_attn(x)
 77        q, k, v = qkv.split(self.n_embd, dim=2)
 78
 79        # Multi-head reshape: (B, T, n_head, head_dim) โ†’ (B, n_head, T, head_dim)
 80        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
 81        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
 82        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
 83
 84        # KV Cache ์ฒ˜๋ฆฌ (์ƒ์„ฑ ์‹œ ํšจ์œจํ™”)
 85        if past_kv is not None:
 86            past_k, past_v = past_kv
 87            k = torch.cat([past_k, k], dim=2)
 88            v = torch.cat([past_v, v], dim=2)
 89
 90        present_kv = (k, v) if use_cache else None
 91
 92        # Attention scores
 93        # (B, n_head, T, head_dim) @ (B, n_head, head_dim, T_kv) โ†’ (B, n_head, T, T_kv)
 94        att = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
 95
 96        # Causal mask ์ ์šฉ
 97        T_kv = k.size(2)
 98        # ํ˜„์žฌ ์œ„์น˜๋ถ€ํ„ฐ ์‹œ์ž‘ํ•˜๋Š” ๋งˆ์Šคํฌ (KV cache ๊ณ ๋ ค)
 99        mask = self.bias[:, :, T_kv - T:T_kv, :T_kv]
100        att = att.masked_fill(mask == 0, float('-inf'))
101
102        # Softmax + Dropout
103        att = F.softmax(att, dim=-1)
104        att = self.attn_dropout(att)
105
106        # Apply attention to values
107        y = torch.matmul(att, v)  # (B, n_head, T, head_dim)
108
109        # Reshape back: (B, T, n_embd)
110        y = y.transpose(1, 2).contiguous().view(B, T, C)
111
112        # Output projection + dropout
113        y = self.resid_dropout(self.c_proj(y))
114
115        return y, present_kv
116
117
class MLP(nn.Module):
    """Position-wise feed-forward network (GPT-2 style, 4x expansion)."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # GPT-2 uses the tanh approximation of GELU.
        h = F.gelu(self.c_fc(x), approximate='tanh')
        return self.dropout(self.c_proj(h))
134
135
class Block(nn.Module):
    """Pre-LayerNorm transformer block: LN -> attn and LN -> MLP, each with a residual add."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(
        self,
        x: torch.Tensor,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Attention sub-layer (pre-LN), passing the KV cache through.
        attn_out, present_kv = self.attn(
            self.ln_1(x), use_cache=use_cache, past_kv=past_kv
        )
        x = x + attn_out
        # MLP sub-layer (pre-LN) with its own residual connection.
        return x + self.mlp(self.ln_2(x)), present_kv
159
160
class GPT(nn.Module):
    """GPT-2 language model.

    Token + learned position embeddings feed a stack of pre-LN transformer
    blocks; the input embedding and the LM head share one weight matrix
    (weight tying). Supports an optional per-layer KV cache for incremental
    decoding.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embd),  # token embedding
            'wpe': nn.Embedding(config.block_size, config.n_embd),  # position embedding
            'drop': nn.Dropout(config.dropout),
            'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(config.n_embd),
        })

        # LM head; weight tying with the token embedding below.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer['wte'].weight = self.lm_head.weight

        self.apply(self._init_weights)

        # GPT-2: scale residual-path projections by 1/sqrt(2 * n_layer) so the
        # residual stream's variance stays bounded with depth.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        """Normal(0, 0.02) for Linear/Embedding weights; zeros for biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        idx: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_key_values: Optional[list] = None
    ):
        """
        Args:
            idx: (batch, seq_len) token indices
            targets: (batch, seq_len) training labels (optional); positions
                with value -100 are ignored in the loss
            use_cache: return per-layer KV caches for incremental decoding
            past_key_values: per-layer (k, v) caches from a previous call

        Returns:
            dict with 'logits' (batch, seq_len, vocab_size), 'loss'
            (scalar tensor or None) and 'past_key_values' (list or None)

        Raises:
            ValueError: if cached length + seq_len exceeds block_size.
        """
        device = idx.device
        B, T = idx.shape

        # Position ids continue from the cached length, if any.
        past_length = past_key_values[0][0].size(2) if past_key_values is not None else 0
        if past_length + T > self.config.block_size:
            # Fix: fail loudly with a clear message instead of an opaque
            # out-of-range error from the position embedding lookup.
            raise ValueError(
                f"sequence length {past_length + T} exceeds block_size {self.config.block_size}"
            )
        pos = torch.arange(past_length, past_length + T, device=device)

        # Embeddings (position embedding broadcasts over the batch).
        tok_emb = self.transformer['wte'](idx)  # (B, T, n_embd)
        pos_emb = self.transformer['wpe'](pos)  # (T, n_embd)
        x = self.transformer['drop'](tok_emb + pos_emb)

        # Transformer blocks, collecting per-layer caches when requested.
        present_key_values = [] if use_cache else None

        for i, block in enumerate(self.transformer['h']):
            past_kv = past_key_values[i] if past_key_values is not None else None
            x, present_kv = block(x, use_cache=use_cache, past_kv=past_kv)
            if use_cache:
                present_key_values.append(present_kv)

        # Final layer norm + LM head.
        x = self.transformer['ln_f'](x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100
            )

        return {
            'logits': logits,
            'loss': loss,
            'past_key_values': present_key_values
        }

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        use_cache: bool = True
    ) -> torch.Tensor:
        """
        Autoregressively sample max_new_tokens continuation tokens.

        Args:
            idx: (batch, seq_len) prompt tokens
            max_new_tokens: number of tokens to generate
            temperature: softmax temperature; values <= 0 fall back to
                greedy argmax decoding (previously this divided by zero)
            top_k: keep only the k most likely tokens before sampling
            top_p: nucleus sampling threshold
            use_cache: reuse per-layer KV caches between steps

        Returns:
            idx: (batch, seq_len + max_new_tokens)
        """
        past_key_values = None

        for _ in range(max_new_tokens):
            # Fix: a cache already at block_size leaves no room for another
            # position — the original crashed in forward() once generation
            # exceeded the context window with use_cache=True. Drop the cache
            # and re-encode the most recent window from scratch.
            if past_key_values is not None and \
                    past_key_values[0][0].size(2) >= self.config.block_size:
                past_key_values = None

            if past_key_values is None:
                # No cache: feed the whole context, cropped to block_size.
                idx_cond = idx[:, -self.config.block_size:]
            else:
                # Cache present: only the newest token is needed.
                idx_cond = idx[:, -1:]

            outputs = self(idx_cond, use_cache=use_cache, past_key_values=past_key_values)
            logits = outputs['logits'][:, -1, :]  # last position only
            past_key_values = outputs['past_key_values']

            if temperature <= 0:
                # Degenerate temperature: deterministic greedy decoding.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
                idx = torch.cat([idx, idx_next], dim=1)
                continue

            logits = logits / temperature

            # Top-K: everything below the k-th best logit goes to -inf.
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            # Top-P (nucleus): drop the tail whose cumulative probability
            # exceeds top_p.
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')

            # Sample from the filtered distribution.
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            idx = torch.cat([idx, idx_next], dim=1)

        return idx
328
329
# Smoke tests
if __name__ == "__main__":
    print("=== GPT-2 Low-Level Implementation Test ===\n")

    # GPT-2 Small configuration
    config = GPTConfig(
        vocab_size=50257,
        block_size=1024,
        n_layer=12,
        n_head=12,
        n_embd=768,
        dropout=0.1
    )

    model = GPT(config)
    # Fix: run the tests in eval mode. In train mode dropout makes every
    # forward pass stochastic, and the cached / uncached generation paths
    # draw different amounts of randomness, so the "cache results match"
    # assertion below can fail even with identical seeds.
    model.eval()

    # Parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    print(f"Expected ~117M for GPT-2 Small\n")

    # Test input
    batch_size, seq_len = 2, 64
    idx = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    targets = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    # Forward pass
    outputs = model(idx, targets=targets)

    print("Forward pass:")
    print(f"  Logits shape: {outputs['logits'].shape}")
    print(f"  Loss: {outputs['loss'].item():.4f}")

    # Generation test
    print("\n=== Generation Test ===")
    start_tokens = torch.tensor([[50256]])  # <|endoftext|>

    generated = model.generate(
        start_tokens,
        max_new_tokens=20,
        temperature=0.8,
        top_k=50
    )

    print(f"Generated shape: {generated.shape}")
    print(f"Generated tokens: {generated[0].tolist()[:25]}...")

    # KV cache test: same seed with and without cache must agree
    print("\n=== KV Cache Test ===")
    import time

    # Without cache
    torch.manual_seed(42)
    start = time.time()
    gen_no_cache = model.generate(start_tokens, max_new_tokens=50, use_cache=False)
    time_no_cache = time.time() - start

    # With cache
    torch.manual_seed(42)
    start = time.time()
    gen_with_cache = model.generate(start_tokens, max_new_tokens=50, use_cache=True)
    time_with_cache = time.time() - start

    print(f"Without cache: {time_no_cache:.3f}s")
    print(f"With cache: {time_with_cache:.3f}s")
    print(f"Speedup: {time_no_cache / time_with_cache:.2f}x")

    # Verify both paths produced identical tokens
    assert torch.equal(gen_no_cache, gen_with_cache), "Cache results should match!"
    print("Cache results match!")

    print("\nAll tests passed!")