# bert_lowlevel.py
"""
Low-level PyTorch BERT implementation.

Does not use nn.TransformerEncoder; only basic operations such as
F.linear and F.layer_norm are used.
"""

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


 15class BertEmbeddings(nn.Module):
 16    """
 17    BERT Embeddings = Token + Segment + Position
 18
 19    Token: ๋‹จ์–ด ์˜๋ฏธ
 20    Segment: ๋ฌธ์žฅ A/B ๊ตฌ๋ถ„
 21    Position: ์œ„์น˜ ์ •๋ณด (ํ•™์Šต ๊ฐ€๋Šฅ)
 22    """
 23
 24    def __init__(
 25        self,
 26        vocab_size: int,
 27        hidden_size: int,
 28        max_position: int = 512,
 29        type_vocab_size: int = 2,  # ๋ฌธ์žฅ A, B
 30        dropout: float = 0.1
 31    ):
 32        super().__init__()
 33        self.hidden_size = hidden_size
 34
 35        # Embedding tables (nn.Embedding ์‚ฌ์šฉํ•˜์ง€๋งŒ ๊ฐœ๋…์ ์œผ๋กœ lookup)
 36        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
 37        self.position_embeddings = nn.Embedding(max_position, hidden_size)
 38        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
 39
 40        # Layer Norm + Dropout
 41        self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)
 42        self.dropout = nn.Dropout(dropout)
 43
 44    def forward(
 45        self,
 46        input_ids: torch.Tensor,
 47        token_type_ids: Optional[torch.Tensor] = None,
 48        position_ids: Optional[torch.Tensor] = None
 49    ) -> torch.Tensor:
 50        """
 51        Args:
 52            input_ids: (batch, seq_len) ํ† ํฐ ID
 53            token_type_ids: (batch, seq_len) ์„ธ๊ทธ๋จผํŠธ ID (0 or 1)
 54            position_ids: (batch, seq_len) ์œ„์น˜ ID
 55
 56        Returns:
 57            embeddings: (batch, seq_len, hidden_size)
 58        """
 59        batch_size, seq_len = input_ids.shape
 60
 61        # ๊ธฐ๋ณธ๊ฐ’ ์„ค์ •
 62        if token_type_ids is None:
 63            token_type_ids = torch.zeros_like(input_ids)
 64
 65        if position_ids is None:
 66            position_ids = torch.arange(seq_len, device=input_ids.device)
 67            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
 68
 69        # ์„ธ ๊ฐ€์ง€ ์ž„๋ฒ ๋”ฉ ํ•ฉ์‚ฐ
 70        word_emb = self.word_embeddings(input_ids)
 71        position_emb = self.position_embeddings(position_ids)
 72        token_type_emb = self.token_type_embeddings(token_type_ids)
 73
 74        embeddings = word_emb + position_emb + token_type_emb
 75
 76        # Layer Norm + Dropout
 77        embeddings = self.layer_norm(embeddings)
 78        embeddings = self.dropout(embeddings)
 79
 80        return embeddings
 81
 82
 83class BertSelfAttention(nn.Module):
 84    """Multi-Head Self-Attention (Low-Level)"""
 85
 86    def __init__(
 87        self,
 88        hidden_size: int,
 89        num_heads: int,
 90        dropout: float = 0.1
 91    ):
 92        super().__init__()
 93        assert hidden_size % num_heads == 0
 94
 95        self.num_heads = num_heads
 96        self.head_dim = hidden_size // num_heads
 97        self.scale = math.sqrt(self.head_dim)
 98
 99        # Q, K, V projections (nn.Linear ๋Œ€์‹  ํŒŒ๋ผ๋ฏธํ„ฐ ์ง์ ‘ ๊ด€๋ฆฌ ๊ฐ€๋Šฅ)
100        self.query = nn.Linear(hidden_size, hidden_size)
101        self.key = nn.Linear(hidden_size, hidden_size)
102        self.value = nn.Linear(hidden_size, hidden_size)
103
104        self.dropout = nn.Dropout(dropout)
105
106    def forward(
107        self,
108        hidden_states: torch.Tensor,
109        attention_mask: Optional[torch.Tensor] = None
110    ) -> Tuple[torch.Tensor, torch.Tensor]:
111        """
112        Args:
113            hidden_states: (batch, seq_len, hidden_size)
114            attention_mask: (batch, 1, 1, seq_len) ๋˜๋Š” (batch, seq_len)
115
116        Returns:
117            context: (batch, seq_len, hidden_size)
118            attention_weights: (batch, num_heads, seq_len, seq_len)
119        """
120        batch_size, seq_len, _ = hidden_states.shape
121
122        # Q, K, V ๊ณ„์‚ฐ
123        Q = self.query(hidden_states)
124        K = self.key(hidden_states)
125        V = self.value(hidden_states)
126
127        # Multi-head reshape: (batch, seq, hidden) โ†’ (batch, heads, seq, head_dim)
128        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
129        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
130        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
131
132        # Attention scores: (batch, heads, seq, seq)
133        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
134
135        # Attention mask ์ ์šฉ
136        if attention_mask is not None:
137            # (batch, seq) โ†’ (batch, 1, 1, seq)
138            if attention_mask.dim() == 2:
139                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
140            # 0 โ†’ -inf, 1 โ†’ 0
141            attention_mask = (1.0 - attention_mask) * -10000.0
142            scores = scores + attention_mask
143
144        # Softmax + Dropout
145        attention_weights = F.softmax(scores, dim=-1)
146        attention_weights = self.dropout(attention_weights)
147
148        # Context: (batch, heads, seq, head_dim)
149        context = torch.matmul(attention_weights, V)
150
151        # Reshape back: (batch, seq, hidden)
152        context = context.transpose(1, 2).contiguous()
153        context = context.view(batch_size, seq_len, -1)
154
155        return context, attention_weights
156
157
class BertSelfOutput(nn.Module):
    """Post-attention projection followed by residual add and layer norm."""

    def __init__(self, hidden_size: int, dropout: float = 0.1):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Project, drop out, add the residual, and normalize.

        Args:
            hidden_states: raw attention output.
            input_tensor: original sublayer input, used as the residual branch.
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.layer_norm(projected + input_tensor)


class BertIntermediate(nn.Module):
    """First feed-forward layer: expand hidden_size to intermediate_size."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)
        # BERT's FFN activation is GELU (applied in forward).

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return gelu(W x + b) with shape (..., intermediate_size)."""
        return F.gelu(self.dense(hidden_states))


class BertOutput(nn.Module):
    """Second feed-forward layer: contract back to hidden_size + residual."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        dropout: float = 0.1
    ):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Contract to hidden_size, drop out, add residual, normalize."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.layer_norm(contracted + input_tensor)


class BertLayer(nn.Module):
    """One BERT encoder block: attention sublayer + feed-forward sublayer."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        intermediate_size: int,
        dropout: float = 0.1
    ):
        super().__init__()
        # Attention sublayer (attention + projection/residual/norm).
        self.attention = BertSelfAttention(hidden_size, num_heads, dropout)
        self.attention_output = BertSelfOutput(hidden_size, dropout)

        # Feed-forward sublayer (expand + contract/residual/norm).
        self.intermediate = BertIntermediate(hidden_size, intermediate_size)
        self.output = BertOutput(hidden_size, intermediate_size, dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run both sublayers; returns (layer_output, attention_weights)."""
        # Attention sublayer, residual around the original input.
        attn_out, attn_weights = self.attention(hidden_states, attention_mask)
        attn_out = self.attention_output(attn_out, hidden_states)

        # Feed-forward sublayer, residual around the attention output.
        ffn_out = self.output(self.intermediate(attn_out), attn_out)
        return ffn_out, attn_weights


class BertEncoder(nn.Module):
    """Stack of identical BertLayer blocks applied in sequence."""

    def __init__(
        self,
        num_layers: int,
        hidden_size: int,
        num_heads: int,
        intermediate_size: int,
        dropout: float = 0.1
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            BertLayer(hidden_size, num_heads, intermediate_size, dropout)
            for _ in range(num_layers)
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False
    ) -> Tuple[torch.Tensor, Optional[list]]:
        """Apply every layer in order.

        Returns the final hidden states and, when output_attentions is
        True, a list of each layer's attention weights (otherwise None).
        """
        all_attentions = [] if output_attentions else None
        for layer in self.layers:
            hidden_states, weights = layer(hidden_states, attention_mask)
            if all_attentions is not None:
                all_attentions.append(weights)
        return hidden_states, all_attentions


class BertPooler(nn.Module):
    """Sequence-level representation taken from the [CLS] token."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq_len, hidden_size)

        Returns:
            (batch, hidden_size) tanh-squashed projection of the first
            ([CLS]) token's hidden state.
        """
        first_token = hidden_states[:, 0]
        return torch.tanh(self.dense(first_token))


class BertModel(nn.Module):
    """BERT base model: embeddings -> encoder stack -> pooler."""

    def __init__(
        self,
        vocab_size: int = 30522,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        intermediate_size: int = 3072,
        max_position: int = 512,
        type_vocab_size: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()
        self.hidden_size = hidden_size

        self.embeddings = BertEmbeddings(
            vocab_size, hidden_size, max_position, type_vocab_size, dropout
        )
        self.encoder = BertEncoder(
            num_layers, hidden_size, num_heads, intermediate_size, dropout
        )
        self.pooler = BertPooler(hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        output_attentions: bool = False
    ):
        """
        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len); 1 = valid token, 0 = padding.
            token_type_ids: (batch, seq_len); 0 = sentence A, 1 = sentence B.
            output_attentions: collect per-layer attention weights when True.

        Returns:
            dict with:
                'last_hidden_state': (batch, seq_len, hidden_size)
                'pooler_output': (batch, hidden_size)
                'attentions': list of per-layer weights, or None
        """
        embedded = self.embeddings(input_ids, token_type_ids=token_type_ids)

        sequence_output, attentions = self.encoder(
            embedded, attention_mask, output_attentions
        )

        return {
            'last_hidden_state': sequence_output,
            'pooler_output': self.pooler(sequence_output),
            'attentions': attentions
        }


class BertForMaskedLM(nn.Module):
    """BERT with a masked-language-modeling head on top."""

    def __init__(self, config: dict):
        super().__init__()
        self.bert = BertModel(**config)

        # Fall back to BertModel's own defaults so a partial config (one that
        # relies on those defaults) no longer raises KeyError here.
        hidden_size = config.get('hidden_size', 768)
        vocab_size = config.get('vocab_size', 30522)

        # MLM head: transform + GELU + layer norm, then decode to vocab logits.
        self.cls = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size, eps=1e-12),
            nn.Linear(hidden_size, vocab_size)
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ):
        """
        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len); 1 = valid, 0 = padding.
            token_type_ids: (batch, seq_len) segment ids.
            labels: (batch, seq_len) target token ids; positions set to -100
                are ignored by the loss.

        Returns:
            dict with 'loss' (None when labels is None), 'logits'
            (batch, seq_len, vocab_size), and 'hidden_states'.
        """
        outputs = self.bert(input_ids, attention_mask, token_type_ids)
        hidden_states = outputs['last_hidden_state']

        # Per-position vocabulary logits.
        prediction_scores = self.cls(hidden_states)

        loss = None
        if labels is not None:
            # Flatten to (batch*seq, vocab) vs (batch*seq,) for cross-entropy.
            loss = F.cross_entropy(
                prediction_scores.view(-1, prediction_scores.size(-1)),
                labels.view(-1),
                ignore_index=-100
            )

        return {
            'loss': loss,
            'logits': prediction_scores,
            'hidden_states': hidden_states
        }


class BertForSequenceClassification(nn.Module):
    """BERT with a linear classification head on the pooled [CLS] output."""

    def __init__(self, config: dict, num_labels: int):
        super().__init__()
        self.bert = BertModel(**config)
        # Use BertModel's default hidden size when config omits it, instead of
        # raising KeyError for configs that rely on the defaults.
        self.classifier = nn.Linear(config.get('hidden_size', 768), num_labels)
        self.dropout = nn.Dropout(config.get('dropout', 0.1))

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ):
        """
        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len); 1 = valid, 0 = padding.
            token_type_ids: (batch, seq_len) segment ids.
            labels: (batch,) class indices for cross-entropy.

        Returns:
            dict with 'loss' (None when labels is None) and 'logits'
            (batch, num_labels).
        """
        outputs = self.bert(input_ids, attention_mask, token_type_ids)
        pooled_output = self.dropout(outputs['pooler_output'])
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return {
            'loss': loss,
            'logits': logits
        }


# Smoke test
if __name__ == "__main__":
    print("=== BERT Low-Level Implementation Test ===\n")

    # BERT-Base configuration.
    config = {
        'vocab_size': 30522,
        'hidden_size': 768,
        'num_layers': 12,
        'num_heads': 12,
        'intermediate_size': 3072,
        'max_position': 512,
        'dropout': 0.1
    }

    # Build the base model.
    model = BertModel(**config)

    # Parameter-count sanity check.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    print(f"Expected ~110M for BERT-Base\n")

    # Dummy inputs.
    batch_size, seq_len = 2, 128
    input_ids = torch.randint(0, config['vocab_size'], (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len)
    token_type_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)

    # Forward pass.
    outputs = model(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        output_attentions=True
    )

    print("Output shapes:")
    print(f"  last_hidden_state: {outputs['last_hidden_state'].shape}")
    print(f"  pooler_output: {outputs['pooler_output'].shape}")
    print(f"  attentions: {len(outputs['attentions'])} layers")
    print(f"  attention shape: {outputs['attentions'][0].shape}")

    # MLM test
    print("\n=== MLM Test ===")
    mlm_model = BertForMaskedLM(config)

    # Supervise ~15% of positions with the true token; ignore the rest (-100).
    # BUGFIX: the previous version kept only positions whose *random* label
    # happened to equal 103, which almost surely ignored every position and
    # made the cross-entropy loss NaN.
    labels = input_ids.clone()
    supervised = torch.rand(batch_size, seq_len) < 0.15
    labels[~supervised] = -100

    mlm_outputs = mlm_model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels
    )

    print(f"MLM Loss: {mlm_outputs['loss'].item():.4f}")
    print(f"Logits shape: {mlm_outputs['logits'].shape}")

    # Classification test
    print("\n=== Classification Test ===")
    clf_model = BertForSequenceClassification(config, num_labels=2)

    labels = torch.randint(0, 2, (batch_size,))
    clf_outputs = clf_model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels
    )

    print(f"Classification Loss: {clf_outputs['loss'].item():.4f}")
    print(f"Logits shape: {clf_outputs['logits'].shape}")

    print("\nAll tests passed!")