1"""
2PyTorch Low-Level BERT ๊ตฌํ
3
4nn.TransformerEncoder ๋ฏธ์ฌ์ฉ
5F.linear, F.layer_norm ๋ฑ ๊ธฐ๋ณธ ์ฐ์ฐ๋ง ์ฌ์ฉ
6"""
7
8import torch
9import torch.nn as nn
10import torch.nn.functional as F
11import math
12from typing import Optional, Tuple
13
14
class BertEmbeddings(nn.Module):
    """Sum of token, segment (A/B), and learned position embeddings.

    Token embeddings carry word identity, segment embeddings distinguish
    sentence A from sentence B, and position embeddings are learned
    (not sinusoidal). The three lookups are added element-wise, then
    layer-normalized and regularized, following the original BERT recipe.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        max_position: int = 512,
        type_vocab_size: int = 2,  # sentence A / sentence B
        dropout: float = 0.1
    ):
        super().__init__()
        self.hidden_size = hidden_size

        # One learnable lookup table per embedding kind.
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position, hidden_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)

        # Post-sum normalization; eps=1e-12 matches the original BERT release.
        self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Embed ``input_ids``.

        Args:
            input_ids: (batch, seq_len) token ids.
            token_type_ids: (batch, seq_len) segment ids (0 or 1);
                all zeros when omitted.
            position_ids: (batch, seq_len) position ids; defaults to
                0..seq_len-1 broadcast over the batch.

        Returns:
            (batch, seq_len, hidden_size) embeddings.
        """
        batch_size, seq_len = input_ids.shape

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if position_ids is None:
            position_ids = (
                torch.arange(seq_len, device=input_ids.device)
                .unsqueeze(0)
                .expand(batch_size, -1)
            )

        # Element-wise sum of the three lookups.
        summed = (
            self.word_embeddings(input_ids)
            + self.position_embeddings(position_ids)
            + self.token_type_embeddings(token_type_ids)
        )

        return self.dropout(self.layer_norm(summed))
81
82
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention, built from primitives."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout: float = 0.1
    ):
        super().__init__()
        assert hidden_size % num_heads == 0

        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = math.sqrt(self.head_dim)

        # Separate projections for query / key / value.
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor, batch: int, seq: int) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        return x.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: (batch, seq_len, hidden_size)
            attention_mask: (batch, 1, 1, seq_len) or (batch, seq_len);
                1 marks valid positions, 0 marks padding.

        Returns:
            context: (batch, seq_len, hidden_size)
            attention_weights: (batch, num_heads, seq_len, seq_len)
        """
        batch, seq, _ = hidden_states.shape

        q = self._split_heads(self.query(hidden_states), batch, seq)
        k = self._split_heads(self.key(hidden_states), batch, seq)
        v = self._split_heads(self.value(hidden_states), batch, seq)

        # Scaled similarity: (batch, heads, seq, seq)
        scores = q @ k.transpose(-2, -1) / self.scale

        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # (batch, seq) -> broadcastable (batch, 1, 1, seq)
                attention_mask = attention_mask[:, None, None, :]
            # valid (1) -> 0 bias, padded (0) -> large negative bias
            scores = scores + (1.0 - attention_mask) * -10000.0

        # Softmax over keys, then attention dropout.
        probs = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge heads back to hidden_size.
        context = (probs @ v).transpose(1, 2).contiguous().view(batch, seq, -1)

        return context, probs
156
157
class BertSelfOutput(nn.Module):
    """Projects the attention output, then residual-add and layer-norm."""

    def __init__(self, hidden_size: int, dropout: float = 0.1):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Apply dense + dropout, add the residual, then normalize.

        Args:
            hidden_states: output of the attention sub-layer.
            input_tensor: sub-layer input, used as the residual branch.
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.layer_norm(projected + input_tensor)
182
183
class BertIntermediate(nn.Module):
    """First (expanding) feed-forward layer with GELU activation."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # BERT uses GELU rather than ReLU.
        return F.gelu(self.dense(hidden_states))
196
197
class BertOutput(nn.Module):
    """Second (contracting) feed-forward layer with residual + layer norm."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        dropout: float = 0.1
    ):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Contract back to hidden_size, add the residual, then normalize.

        Args:
            hidden_states: (batch, seq, intermediate_size) FFN activations.
            input_tensor: (batch, seq, hidden_size) residual branch.
        """
        contracted = self.dropout(self.dense(hidden_states))
        return self.layer_norm(contracted + input_tensor)
221
222
class BertLayer(nn.Module):
    """One encoder block: self-attention sub-layer + feed-forward sub-layer."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        intermediate_size: int,
        dropout: float = 0.1
    ):
        super().__init__()
        # Attention sub-layer.
        self.attention = BertSelfAttention(hidden_size, num_heads, dropout)
        self.attention_output = BertSelfOutput(hidden_size, dropout)

        # Feed-forward sub-layer.
        self.intermediate = BertIntermediate(hidden_size, intermediate_size)
        self.output = BertOutput(hidden_size, intermediate_size, dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(layer_output, attention_weights)`` for one block."""
        # Attention (residual + layer norm happen inside attention_output).
        attn_ctx, attn_weights = self.attention(hidden_states, attention_mask)
        attn_out = self.attention_output(attn_ctx, hidden_states)

        # Feed-forward (residual + layer norm happen inside output).
        layer_out = self.output(self.intermediate(attn_out), attn_out)

        return layer_out, attn_weights
258
259
class BertEncoder(nn.Module):
    """Stack of identical BertLayer blocks applied in sequence."""

    def __init__(
        self,
        num_layers: int,
        hidden_size: int,
        num_heads: int,
        intermediate_size: int,
        dropout: float = 0.1
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            BertLayer(hidden_size, num_heads, intermediate_size, dropout)
            for _ in range(num_layers)
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False
    ) -> Tuple[torch.Tensor, Optional[list]]:
        """Run every layer; optionally collect per-layer attention maps.

        Returns:
            final hidden states and, when ``output_attentions`` is True,
            a list with one attention tensor per layer (otherwise None).
        """
        collected = [] if output_attentions else None

        for block in self.layers:
            hidden_states, weights = block(hidden_states, attention_mask)
            if collected is not None:
                collected.append(weights)

        return hidden_states, collected
291
292
class BertPooler(nn.Module):
    """Tanh-projected representation of the first ([CLS]) token."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq_len, hidden_size)

        Returns:
            (batch, hidden_size) pooled representation of the [CLS]
            (first) token.
        """
        first_token = hidden_states[:, 0]
        return torch.tanh(self.dense(first_token))
313
314
class BertModel(nn.Module):
    """BERT backbone: embeddings -> encoder stack -> [CLS] pooler.

    Defaults match BERT-Base (12 layers, 12 heads, hidden 768).
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        intermediate_size: int = 3072,
        max_position: int = 512,
        type_vocab_size: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()
        self.hidden_size = hidden_size

        self.embeddings = BertEmbeddings(
            vocab_size, hidden_size, max_position, type_vocab_size, dropout
        )
        self.encoder = BertEncoder(
            num_layers, hidden_size, num_heads, intermediate_size, dropout
        )
        self.pooler = BertPooler(hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        output_attentions: bool = False
    ):
        """
        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len); 1 for valid, 0 for padding.
            token_type_ids: (batch, seq_len); 0 for sentence A, 1 for B.
            output_attentions: when True, return per-layer attention maps.

        Returns:
            dict with 'last_hidden_state' (batch, seq_len, hidden_size),
            'pooler_output' (batch, hidden_size), and 'attentions'
            (list of per-layer weights, or None).
        """
        embedded = self.embeddings(input_ids, token_type_ids=token_type_ids)

        sequence_output, attentions = self.encoder(
            embedded, attention_mask, output_attentions
        )

        return {
            'last_hidden_state': sequence_output,
            'pooler_output': self.pooler(sequence_output),
            'attentions': attentions
        }
376
377
class BertForMaskedLM(nn.Module):
    """BERT with a masked-language-modeling head.

    The head is a transform (dense + GELU + layer norm) followed by a
    projection onto the vocabulary.
    """

    def __init__(self, config: dict):
        """
        Args:
            config: keyword arguments for ``BertModel``. Keys omitted
                from the dict fall back to the same BERT-Base defaults
                that ``BertModel`` uses, so a partial config stays
                consistent between the backbone and the head.
        """
        super().__init__()
        self.bert = BertModel(**config)

        # Fall back to BertModel's defaults: a config missing these keys
        # would previously raise KeyError here even though the backbone
        # accepted it silently.
        hidden_size = config.get('hidden_size', 768)
        vocab_size = config.get('vocab_size', 30522)

        # MLM head: transform + vocabulary projection.
        self.cls = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size, eps=1e-12),
            nn.Linear(hidden_size, vocab_size)
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ):
        """
        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len); 1 for valid, 0 for padding.
            token_type_ids: (batch, seq_len) segment ids.
            labels: (batch, seq_len) target token ids; positions set to
                -100 are ignored by the loss.

        Returns:
            dict with 'loss' (None when labels is None), 'logits'
            (batch, seq_len, vocab_size), and 'hidden_states'.
        """
        outputs = self.bert(input_ids, attention_mask, token_type_ids)
        hidden_states = outputs['last_hidden_state']

        # Per-position vocabulary logits.
        prediction_scores = self.cls(hidden_states)

        loss = None
        if labels is not None:
            # Flatten to (batch*seq, vocab) vs (batch*seq,) for CE.
            loss = F.cross_entropy(
                prediction_scores.view(-1, prediction_scores.size(-1)),
                labels.view(-1),
                ignore_index=-100
            )

        return {
            'loss': loss,
            'logits': prediction_scores,
            'hidden_states': hidden_states
        }
419
420
class BertForSequenceClassification(nn.Module):
    """BERT with a linear classifier over the pooled [CLS] output."""

    def __init__(self, config: dict, num_labels: int):
        """
        Args:
            config: keyword arguments for ``BertModel``; omitted keys
                fall back to the BERT-Base defaults, matching the
                backbone's own defaults.
            num_labels: number of target classes.
        """
        super().__init__()
        self.bert = BertModel(**config)
        # Use .get with BertModel's default (the dropout lookup below
        # already did) so a partial config does not raise KeyError here.
        self.classifier = nn.Linear(config.get('hidden_size', 768), num_labels)
        self.dropout = nn.Dropout(config.get('dropout', 0.1))

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ):
        """
        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len); 1 for valid, 0 for padding.
            token_type_ids: (batch, seq_len) segment ids.
            labels: (batch,) class indices for the loss.

        Returns:
            dict with 'loss' (None when labels is None) and 'logits'
            (batch, num_labels).
        """
        outputs = self.bert(input_ids, attention_mask, token_type_ids)
        pooled_output = outputs['pooler_output']

        # Classify the regularized [CLS] representation.
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return {
            'loss': loss,
            'logits': logits
        }
451
452
# Smoke-test / demo: builds BERT-Base and runs one forward pass per head.
# NOTE(review): statement order matters here — each torch.randint draws
# from the global RNG, so reordering changes the sampled tensors.
if __name__ == "__main__":
    print("=== BERT Low-Level Implementation Test ===\n")

    # BERT-Base configuration.
    config = {
        'vocab_size': 30522,
        'hidden_size': 768,
        'num_layers': 12,
        'num_heads': 12,
        'intermediate_size': 3072,
        'max_position': 512,
        'dropout': 0.1
    }

    # Build the backbone.
    model = BertModel(**config)

    # Parameter count sanity check (~110M expected for BERT-Base).
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    print(f"Expected ~110M for BERT-Base\n")

    # Random test inputs: all positions valid, all segment A.
    batch_size, seq_len = 2, 128
    input_ids = torch.randint(0, config['vocab_size'], (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len)
    token_type_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)

    # Forward pass through the backbone, collecting attention maps.
    outputs = model(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        output_attentions=True
    )

    print("Output shapes:")
    print(f"  last_hidden_state: {outputs['last_hidden_state'].shape}")
    print(f"  pooler_output: {outputs['pooler_output'].shape}")
    print(f"  attentions: {len(outputs['attentions'])} layers")
    print(f"  attention shape: {outputs['attentions'][0].shape}")

    # Masked-language-modeling head.
    print("\n=== MLM Test ===")
    mlm_model = BertForMaskedLM(config)

    # Keep the loss only where the label happens to equal 103
    # (presumably BERT's [MASK] token id — labels are random here,
    # so this is demo plumbing, not a real masking scheme).
    labels = torch.randint(0, config['vocab_size'], (batch_size, seq_len))
    labels[labels != 103] = -100  # predict only [MASK] tokens

    mlm_outputs = mlm_model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels
    )

    print(f"MLM Loss: {mlm_outputs['loss'].item():.4f}")
    print(f"Logits shape: {mlm_outputs['logits'].shape}")

    # Sequence-classification head (binary labels).
    print("\n=== Classification Test ===")
    clf_model = BertForSequenceClassification(config, num_labels=2)

    labels = torch.randint(0, 2, (batch_size,))
    clf_outputs = clf_model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels
    )

    print(f"Classification Loss: {clf_outputs['loss'].item():.4f}")
    print(f"Logits shape: {clf_outputs['logits'].shape}")

    print("\nAll tests passed!")