1"""
2PyTorch Low-Level GPT-2 ๊ตฌํ
3
4nanoGPT ์คํ์ผ์ ๊ฐ๊ฒฐํ ๊ตฌํ
5Pre-LayerNorm, Causal Attention, Weight Tying
6"""
7
8import torch
9import torch.nn as nn
10import torch.nn.functional as F
11import math
12from typing import Optional, Tuple
13from dataclasses import dataclass
14
15
@dataclass
class GPTConfig:
    """Configuration for the GPT-2 model (defaults match GPT-2 small)."""
    vocab_size: int = 50257  # GPT-2 BPE vocabulary size
    block_size: int = 1024   # max sequence length (context window)
    n_layer: int = 12        # number of transformer blocks
    n_head: int = 12         # attention heads per block
    n_embd: int = 768        # embedding / hidden dimension
    dropout: float = 0.1     # dropout probability used throughout the model
    bias: bool = True        # include bias terms in Linear layers
26
27
class CausalSelfAttention(nn.Module):
    """Masked multi-head self-attention with optional KV caching.

    Queries, keys and values come from a single fused projection; a
    lower-triangular mask prevents any position from attending to
    future tokens.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.dropout = config.dropout

        # Fused Q/K/V projection followed by the output projection.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Lower-triangular causal mask. Registered as a buffer so it is
        # not trained but still travels with state_dict / .to(device).
        causal = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer(
            "bias",
            causal.view(1, 1, config.block_size, config.block_size)
        )

    def forward(
        self,
        x: torch.Tensor,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            x: (batch, seq_len, n_embd) input activations
            use_cache: whether to return the current K/V for reuse
            past_kv: previously cached (K, V), each (B, n_head, T_past, head_dim)

        Returns:
            out: (batch, seq_len, n_embd)
            present_kv: (K, V) for caching, or None when use_cache is False
        """
        B, T, C = x.shape

        # One matmul produces Q, K and V; split along the channel axis.
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        # (B, T, C) -> (B, n_head, T, head_dim) for per-head attention.
        q = q.view((B, T, self.n_head, self.head_dim)).transpose(1, 2)
        k = k.view((B, T, self.n_head, self.head_dim)).transpose(1, 2)
        v = v.view((B, T, self.n_head, self.head_dim)).transpose(1, 2)

        # Prepend cached keys/values along the time axis (generation).
        if past_kv is not None:
            k = torch.cat((past_kv[0], k), dim=2)
            v = torch.cat((past_kv[1], v), dim=2)

        present_kv = (k, v) if use_cache else None

        # Scaled dot-product scores:
        # (B, n_head, T, head_dim) @ (B, n_head, head_dim, T_kv) -> (B, n_head, T, T_kv)
        att = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)

        # Causal masking. With a cache, the queries sit at the tail of the
        # key/value timeline, so slice the matching rows of the mask.
        T_kv = k.size(2)
        keep = self.bias[:, :, T_kv - T:T_kv, :T_kv]
        att = att.masked_fill(keep == 0, float('-inf'))

        # Normalize into attention weights, then drop out.
        att = self.attn_dropout(F.softmax(att, dim=-1))

        # Weighted sum over values, then merge heads back to (B, T, C).
        out = (att @ v).transpose(1, 2).contiguous().view(B, T, C)

        # Output projection + residual-path dropout.
        out = self.resid_dropout(self.c_proj(out))

        return out, present_kv
116
117
class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, GELU, project back."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden = 4 * config.n_embd  # GPT-2 uses a 4x hidden expansion
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # GPT-2 historically used the tanh approximation of GELU.
        h = F.gelu(self.c_fc(x), approximate='tanh')
        return self.dropout(self.c_proj(h))
134
135
class Block(nn.Module):
    """Pre-LN transformer block: attention then MLP, each with a residual."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(
        self,
        x: torch.Tensor,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Normalize *before* each sublayer (Pre-LN), add the residual after.
        a, present_kv = self.attn(self.ln_1(x), use_cache=use_cache, past_kv=past_kv)
        h = x + a
        h = h + self.mlp(self.ln_2(h))
        return h, present_kv
159
160
class GPT(nn.Module):
    """GPT-2 language model.

    Token + learned position embeddings, a stack of Pre-LN transformer
    blocks, a final LayerNorm, and an LM head whose weight matrix is
    tied to the token embedding.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embd),  # token embedding
            'wpe': nn.Embedding(config.block_size, config.n_embd),  # position embedding
            'drop': nn.Dropout(config.dropout),
            'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(config.n_embd),
        })

        # LM head; its weight is shared with the token embedding
        # (weight tying, as in the original GPT-2).
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer['wte'].weight = self.lm_head.weight

        self.apply(self._init_weights)

        # GPT-2 trick: scale residual-projection init by 1/sqrt(2*n_layer)
        # (two residual additions per block) so activations stay bounded
        # as depth grows.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights to N(0, 0.02), biases to zero."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        idx: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_key_values: Optional[list] = None
    ):
        """Run the model.

        Args:
            idx: (batch, seq_len) token indices
            targets: (batch, seq_len) labels for training; positions set
                to -100 are ignored by the loss
            use_cache: if True, return per-layer (K, V) tensors for reuse
            past_key_values: list of per-layer (K, V) from a previous call

        Returns:
            dict with 'logits' (batch, seq_len, vocab_size), 'loss'
            (None when targets is None) and 'past_key_values'
            (None when use_cache is False).
        """
        device = idx.device
        B, T = idx.shape

        # Position ids continue from the cached length so position
        # embeddings stay consistent during incremental decoding.
        if past_key_values is not None:
            past_length = past_key_values[0][0].size(2)
            pos = torch.arange(past_length, past_length + T, device=device)
        else:
            pos = torch.arange(0, T, device=device)

        # Embeddings
        tok_emb = self.transformer['wte'](idx)  # (B, T, n_embd)
        pos_emb = self.transformer['wpe'](pos)  # (T, n_embd), broadcast over batch
        x = self.transformer['drop'](tok_emb + pos_emb)

        # Transformer blocks
        present_key_values = [] if use_cache else None

        for i, block in enumerate(self.transformer['h']):
            past_kv = past_key_values[i] if past_key_values is not None else None
            x, present_kv = block(x, use_cache=use_cache, past_kv=past_kv)
            if use_cache:
                present_key_values.append(present_kv)

        # Final layer norm
        x = self.transformer['ln_f'](x)

        # LM head
        logits = self.lm_head(x)  # (B, T, vocab_size)

        # Loss (training only)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100
            )

        return {
            'logits': logits,
            'loss': loss,
            'past_key_values': present_key_values
        }

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        use_cache: bool = True
    ) -> torch.Tensor:
        """Autoregressively sample max_new_tokens tokens.

        Args:
            idx: (batch, seq_len) prompt tokens
            max_new_tokens: number of tokens to append
            temperature: softmax temperature (> 0)
            top_k: if given, keep only the k most likely tokens
            top_p: if given, nucleus (top-p) sampling threshold
            use_cache: reuse per-layer KV caches between steps

        Returns:
            idx: (batch, seq_len + max_new_tokens)

        NOTE: the caller should put the model in eval() mode first;
        dropout is otherwise active during sampling.
        """
        past_key_values = None

        for _ in range(max_new_tokens):
            # BUGFIX: the cache could grow past block_size, overflowing the
            # position embedding and the causal mask. When the cache is
            # full, drop it and recompute over the most recent block_size
            # tokens (sliding window).
            if past_key_values is not None and past_key_values[0][0].size(2) >= self.config.block_size:
                past_key_values = None

            if past_key_values is None:
                # No cache: run on (at most) the last block_size tokens.
                idx_cond = idx[:, -self.config.block_size:]
            else:
                # With a cache only the newest token needs processing.
                idx_cond = idx[:, -1:]

            # Forward
            outputs = self(idx_cond, use_cache=use_cache, past_key_values=past_key_values)
            logits = outputs['logits'][:, -1, :]  # next-token distribution
            past_key_values = outputs['past_key_values']

            # Temperature scaling
            logits = logits / temperature

            # Top-K: mask everything below the k-th largest logit.
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            # Top-P (nucleus): keep the smallest set of tokens whose
            # cumulative probability exceeds top_p.
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')

            # Sample from the filtered distribution.
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # Append to the running sequence.
            idx = torch.cat([idx, idx_next], dim=1)

        return idx
328
329
330# ํ
์คํธ
if __name__ == "__main__":
    print("=== GPT-2 Low-Level Implementation Test ===\n")

    # GPT-2 small hyperparameters
    config = GPTConfig(
        vocab_size=50257,
        block_size=1024,
        n_layer=12,
        n_head=12,
        n_embd=768,
        dropout=0.1
    )

    # Build the model
    model = GPT(config)

    # Parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    print("Expected ~117M for GPT-2 Small\n")

    # Training-style forward pass with random tokens
    batch_size, seq_len = 2, 64
    idx = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    targets = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    outputs = model(idx, targets=targets)

    print("Forward pass:")
    print(f"  Logits shape: {outputs['logits'].shape}")
    print(f"  Loss: {outputs['loss'].item():.4f}")

    # BUGFIX: generation must run in eval mode. In train mode dropout is
    # active, so the cached and uncached paths consume RNG differently and
    # the torch.equal assertion below fails nondeterministically.
    model.eval()

    # Generation test
    print("\n=== Generation Test ===")
    start_tokens = torch.tensor([[50256]])  # <|endoftext|>

    generated = model.generate(
        start_tokens,
        max_new_tokens=20,
        temperature=0.8,
        top_k=50
    )

    print(f"Generated shape: {generated.shape}")
    print(f"Generated tokens: {generated[0].tolist()[:25]}...")

    # KV cache consistency + speed test
    print("\n=== KV Cache Test ===")
    import time

    # Without cache
    torch.manual_seed(42)
    start = time.time()
    gen_no_cache = model.generate(start_tokens, max_new_tokens=50, use_cache=False)
    time_no_cache = time.time() - start

    # With cache
    torch.manual_seed(42)
    start = time.time()
    gen_with_cache = model.generate(start_tokens, max_new_tokens=50, use_cache=True)
    time_with_cache = time.time() - start

    print(f"Without cache: {time_no_cache:.3f}s")
    print(f"With cache: {time_with_cache:.3f}s")
    print(f"Speedup: {time_no_cache / time_with_cache:.2f}x")

    # Both paths must produce identical token sequences
    assert torch.equal(gen_no_cache, gen_with_cache), "Cache results should match!"
    print("Cache results match!")

    print("\nAll tests passed!")