05_gpt_generation.py

  1"""
  205. GPT ํ…์ŠคํŠธ ์ƒ์„ฑ ์˜ˆ์ œ
  3
  4GPT-2๋ฅผ ์‚ฌ์šฉํ•œ ํ…์ŠคํŠธ ์ƒ์„ฑ
  5"""
  6
  7print("=" * 60)
  8print("GPT ํ…์ŠคํŠธ ์ƒ์„ฑ")
  9print("=" * 60)
 10
 11try:
 12    import torch
 13    from transformers import GPT2Tokenizer, GPT2LMHeadModel
 14    import torch.nn.functional as F
 15
 16    # ============================================
 17    # 1. GPT-2 ๋กœ๋“œ
 18    # ============================================
 19    print("\n[1] GPT-2 ๋ชจ๋ธ ๋กœ๋“œ")
 20    print("-" * 40)
 21
 22    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
 23    model = GPT2LMHeadModel.from_pretrained('gpt2')
 24    model.eval()
 25
 26    # ํŒจ๋”ฉ ํ† ํฐ ์„ค์ •
 27    tokenizer.pad_token = tokenizer.eos_token
 28
 29    print(f"์–ดํœ˜ ํฌ๊ธฐ: {tokenizer.vocab_size}")
 30    print(f"๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ: {sum(p.numel() for p in model.parameters()):,}")
 31
 32
 33    # ============================================
 34    # 2. ๊ธฐ๋ณธ ์ƒ์„ฑ (Greedy)
 35    # ============================================
 36    print("\n[2] Greedy ์ƒ์„ฑ")
 37    print("-" * 40)
 38
 39    prompt = "Once upon a time"
 40    input_ids = tokenizer.encode(prompt, return_tensors='pt')
 41
 42    output = model.generate(
 43        input_ids,
 44        max_length=50,
 45        do_sample=False  # Greedy
 46    )
 47
 48    generated = tokenizer.decode(output[0], skip_special_tokens=True)
 49    print(f"ํ”„๋กฌํ”„ํŠธ: {prompt}")
 50    print(f"์ƒ์„ฑ: {generated}")
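
    # A minimal sketch of a single greedy step: the next token is simply the
    # argmax of the model's next-token logits (reuses `model`, `tokenizer`,
    # and `input_ids` from above).
    with torch.no_grad():
        next_logits = model(input_ids).logits[:, -1, :]
    next_id = next_logits.argmax(dim=-1)
    print(f"First greedy continuation token: {tokenizer.decode(next_id)!r}")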


    # ============================================
    # 3. Sampling-based generation
    # ============================================
    print("\n[3] Temperature sampling")
    print("-" * 40)

    prompt = "The future of AI is"
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    for temp in [0.5, 1.0, 1.5]:
        output = model.generate(
            input_ids,
            max_length=40,
            do_sample=True,
            temperature=temp,
            pad_token_id=tokenizer.eos_token_id
        )
        generated = tokenizer.decode(output[0], skip_special_tokens=True)
        print(f"temp={temp}: {generated[:60]}...")
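
    # Sketch of what temperature does to a distribution: dividing logits by
    # T < 1 sharpens the softmax, while T > 1 flattens it. (Toy logits for
    # illustration only, not taken from the model.)
    toy_logits = torch.tensor([2.0, 1.0, 0.5])
    for t in [0.5, 1.0, 1.5]:
        probs = F.softmax(toy_logits / t, dim=-1)
        print(f"T={t}: {[round(p, 3) for p in probs.tolist()]}")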


    # ============================================
    # 4. Top-k / top-p sampling
    # ============================================
    print("\n[4] Top-k / top-p sampling")
    print("-" * 40)

    prompt = "In the year 2050"
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # Top-k: sample only from the k most likely tokens
    output_topk = model.generate(
        input_ids,
        max_length=50,
        do_sample=True,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id
    )
    print(f"Top-k (k=50): {tokenizer.decode(output_topk[0], skip_special_tokens=True)[:70]}...")

    # Top-p (nucleus): sample from the smallest set of tokens whose
    # cumulative probability reaches p
    output_topp = model.generate(
        input_ids,
        max_length=50,
        do_sample=True,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )
    print(f"Top-p (p=0.9): {tokenizer.decode(output_topp[0], skip_special_tokens=True)[:70]}...")
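
    # Sketch of the filtering behind top-k / top-p (simplified compared to the
    # library's logits processors; reuses `model` and `input_ids` from above).
    with torch.no_grad():
        logits = model(input_ids).logits[0, -1, :]
    k, p = 50, 0.9
    kth = torch.topk(logits, k).values[-1]
    topk_kept = int((logits >= kth).sum())  # tokens surviving the top-k cut
    sorted_logits, _ = torch.sort(logits, descending=True)
    cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    topp_kept = int((cumprobs < p).sum()) + 1  # smallest set reaching p
    print(f"top-k (k={k}) keeps {topk_kept} tokens; top-p (p={p}) keeps {topp_kept}")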


    # ============================================
    # 5. Advanced generation parameters
    # ============================================
    print("\n[5] Advanced generation parameters")
    print("-" * 40)

    prompt = "Python is a programming language"
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    output = model.generate(
        input_ids,
        max_length=80,
        min_length=30,
        do_sample=True,
        temperature=0.8,
        top_p=0.92,
        top_k=50,
        no_repeat_ngram_size=2,     # block repeated 2-grams
        repetition_penalty=1.2,     # penalize tokens already generated
        num_return_sequences=2,     # return multiple candidate sequences
        pad_token_id=tokenizer.eos_token_id
    )

    print(f"Prompt: {prompt}")
    for i, out in enumerate(output):
        text = tokenizer.decode(out, skip_special_tokens=True)
        print(f"\nGeneration {i+1}: {text}")
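
    # Sketch of how repetition_penalty works (the CTRL-style rule used by
    # transformers): a logit for a token already in the sequence is divided
    # by the penalty if positive, multiplied by it if negative. (Toy value.)
    penalty, seen_logit = 1.2, 3.0
    penalized = seen_logit / penalty if seen_logit > 0 else seen_logit * penalty
    print(f"Repeated-token logit {seen_logit} -> {penalized:.2f} (penalty={penalty})")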


    # ============================================
    # 6. Manual generation loop
    # ============================================
    print("\n[6] Manual generation (step by step)")
    print("-" * 40)

    def generate_manual(prompt, max_tokens=20, temperature=1.0):
        input_ids = tokenizer.encode(prompt, return_tensors='pt')

        for _ in range(max_tokens):
            with torch.no_grad():
                outputs = model(input_ids)
                logits = outputs.logits[:, -1, :]  # logits for the last position

            # Apply temperature
            probs = F.softmax(logits / temperature, dim=-1)

            # Sample the next token
            next_token = torch.multinomial(probs, num_samples=1)

            # Stop at the end-of-sequence token
            if next_token.item() == tokenizer.eos_token_id:
                break

            input_ids = torch.cat([input_ids, next_token], dim=-1)

        return tokenizer.decode(input_ids[0], skip_special_tokens=True)

    result = generate_manual("The robot said", max_tokens=15, temperature=0.8)
    print(f"Manual generation: {result}")


    # ============================================
    # 7. Conditional generation (prompt-based)
    # ============================================
    print("\n[7] Conditional generation")
    print("-" * 40)

    prompts = [
        "Q: What is machine learning?\nA:",
        "Translate English to French: Hello, how are you? →",
        "Summarize: Artificial intelligence is transforming various industries. →"
    ]

    for prompt in prompts:
        input_ids = tokenizer.encode(prompt, return_tensors='pt')
        output = model.generate(
            input_ids,
            max_new_tokens=30,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id
        )
        result = tokenizer.decode(output[0], skip_special_tokens=True)
        print(f"Input: {prompt[:50]}...")
        print(f"Output: {result[len(prompt):len(prompt)+60]}...")
        print()


    # ============================================
    # Summary
    # ============================================
    print("=" * 60)
    print("GPT Generation Summary")
    print("=" * 60)

    summary = """
Generation strategies:
    - Greedy: do_sample=False, deterministic
    - Temperature: lower = more deterministic, higher = more diverse
    - Top-k: sample from the k most likely tokens
    - Top-p (nucleus): sample from the smallest set with cumulative probability p

Key code:
    output = model.generate(
        input_ids,
        max_length=50,
        do_sample=True,
        temperature=0.8,
        top_p=0.9
    )
"""
    print(summary)

except ImportError as e:
    print(f"Missing required packages: {e}")
    print("pip install torch transformers")