"""
04. BERT basics - HuggingFace BERT usage examples.

Covers: loading a BERT model, text encoding/decoding, embedding
extraction, sentence-pair encoding, sequence classification,
batched inference, and sentence similarity via mean pooling.

NOTE: running this script downloads pretrained weights from the
HuggingFace hub on first use.
"""

print("=" * 60)
print("BERT basics")
print("=" * 60)

try:
    import torch
    import torch.nn.functional as F
    from transformers import BertForSequenceClassification, BertModel, BertTokenizer

    # ============================================
    # 1. Load tokenizer and model
    # ============================================
    print("\n[1] Loading the BERT model")
    print("-" * 40)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')

    print(f"Vocabulary size: {tokenizer.vocab_size}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # ============================================
    # 2. Text encoding
    # ============================================
    print("\n[2] Text encoding")
    print("-" * 40)

    text = "Hello, how are you?"

    # Tokenization into WordPiece subwords (no special tokens yet)
    tokens = tokenizer.tokenize(text)
    print(f"Text: {text}")
    print(f"Tokens: {tokens}")

    # Full encoding: adds [CLS]/[SEP] and returns PyTorch tensors
    encoded = tokenizer(text, return_tensors='pt')
    print(f"input_ids: {encoded['input_ids']}")
    print(f"attention_mask: {encoded['attention_mask']}")

    # Decoding ids back to text (special tokens included)
    decoded = tokenizer.decode(encoded['input_ids'][0])
    print(f"Decoded: {decoded}")

    # ============================================
    # 3. Extracting BERT embeddings
    # ============================================
    print("\n[3] Extracting BERT embeddings")
    print("-" * 40)

    model.eval()  # disable dropout for deterministic inference
    with torch.no_grad():
        outputs = model(**encoded)

    # Output structure
    last_hidden_state = outputs.last_hidden_state  # (batch, seq, hidden)
    pooler_output = outputs.pooler_output          # (batch, hidden) - transformed [CLS]

    print(f"last_hidden_state shape: {last_hidden_state.shape}")
    print(f"pooler_output shape: {pooler_output.shape}")

    # Raw [CLS] token embedding (first token of the sequence)
    cls_embedding = last_hidden_state[0, 0]
    print(f"[CLS] embedding shape: {cls_embedding.shape}")

    # ============================================
    # 4. Sentence-pair encoding
    # ============================================
    print("\n[4] Sentence-pair encoding")
    print("-" * 40)

    text_a = "How old are you?"
    text_b = "I am 25 years old."

    encoded_pair = tokenizer(text_a, text_b, return_tensors='pt')
    print(f"Sentence A: {text_a}")
    print(f"Sentence B: {text_b}")
    print(f"token_type_ids: {encoded_pair['token_type_ids']}")
    # [0, 0, ..., 0, 1, 1, ..., 1] - segment A is 0, segment B is 1

    # ============================================
    # 5. Sequence classification
    # ============================================
    print("\n[5] Sequence classification")
    print("-" * 40)

    # NOTE: the classification head on top of 'bert-base-uncased' is randomly
    # initialized (the base checkpoint ships no fine-tuned head), so these
    # "sentiment" predictions are essentially random until fine-tuning.
    classifier = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=2
    )

    sample_texts = [
        "I love this movie! It's amazing.",
        "This is terrible. I hate it.",
        "The weather is nice today."
    ]

    classifier.eval()
    # 'sample' (not 'text') so the earlier `text`/`encoded` pair stays intact
    for sample in sample_texts:
        inputs = tokenizer(sample, return_tensors='pt', truncation=True, max_length=128)

        with torch.no_grad():
            logits = classifier(**inputs).logits
            probs = F.softmax(logits, dim=-1)
            pred = logits.argmax(dim=-1).item()

        label = "Positive" if pred == 1 else "Negative"
        conf = probs[0, pred].item()
        print(f"[{label}] ({conf:.2%}) {sample[:40]}...")

    # ============================================
    # 6. Batch processing
    # ============================================
    print("\n[6] Batch processing")
    print("-" * 40)

    texts = ["Hello world", "How are you?", "I'm fine, thanks!"]

    # Batch encoding: pad to the longest sequence in the batch
    batch_encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=32,
        return_tensors='pt'
    )

    print(f"Batch input_ids shape: {batch_encoded['input_ids'].shape}")

    # Batched inference in a single forward pass
    model.eval()
    with torch.no_grad():
        batch_outputs = model(**batch_encoded)

    print(f"Batch output shape: {batch_outputs.last_hidden_state.shape}")

    # ============================================
    # 7. Sentence similarity
    # ============================================
    print("\n[7] Sentence similarity")
    print("-" * 40)

    def get_sentence_embedding(text, model, tokenizer):
        """Return a fixed-size sentence vector by mean-pooling token embeddings.

        Args:
            text: input sentence (truncated to 128 tokens).
            model: a BertModel in eval mode.
            tokenizer: the matching BertTokenizer.

        Returns:
            1-D tensor of shape (hidden_size,).
        """
        inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=128)
        with torch.no_grad():
            outputs = model(**inputs)
        # Mean pooling over the sequence dimension ([CLS] alone also works)
        return outputs.last_hidden_state.mean(dim=1).squeeze()

    sentences = [
        "I love programming",
        "Coding is my passion",
        "I enjoy eating pizza"
    ]

    embeddings = [get_sentence_embedding(s, model, tokenizer) for s in sentences]

    print("Sentence similarity:")
    for i in range(len(sentences)):
        for j in range(i + 1, len(sentences)):
            sim = F.cosine_similarity(embeddings[i].unsqueeze(0), embeddings[j].unsqueeze(0))
            print(f"  '{sentences[i][:20]}...' vs '{sentences[j][:20]}...': {sim.item():.4f}")

    # ============================================
    # Summary
    # ============================================
    print("\n" + "=" * 60)
    print("BERT summary")
    print("=" * 60)

    summary = """
BERT usage pattern:
    # Load
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')

    # Encode
    inputs = tokenizer(text, return_tensors='pt')

    # Embeddings
    outputs = model(**inputs)
    cls_embedding = outputs.last_hidden_state[:, 0]  # [CLS]

    # Classification
    classifier = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=2
    )
    logits = classifier(**inputs).logits
"""
    print(summary)

except ImportError as e:
    print(f"Missing required package: {e}")
    print("pip install torch transformers")