# 14_rlhf_reward_model.py
  1"""
  214. RLHF์™€ LLM ์ •๋ ฌ (Alignment) ์˜ˆ์ œ
  3
  4Reward Model, PPO, DPO, Constitutional AI ์‹ค์Šต
  5"""
  6
  7import numpy as np
  8import random
  9
 10print("=" * 60)
 11print("RLHF์™€ LLM ์ •๋ ฌ (Alignment)")
 12print("=" * 60)
 13
 14
 15# ============================================
 16# 1. ์„ ํ˜ธ๋„ ๋ฐ์ดํ„ฐ ์ดํ•ด
 17# ============================================
 18print("\n[1] ์„ ํ˜ธ๋„ ๋ฐ์ดํ„ฐ ํ˜•์‹")
 19print("-" * 40)
 20
 21# ์„ ํ˜ธ๋„ ๋ฐ์ดํ„ฐ ์˜ˆ์‹œ
 22preference_data = [
 23    {
 24        "prompt": "์ธ๊ณต์ง€๋Šฅ์ด๋ž€ ๋ฌด์—‡์ธ๊ฐ€์š”?",
 25        "chosen": "์ธ๊ณต์ง€๋Šฅ(AI)์€ ์ปดํ“จํ„ฐ ์‹œ์Šคํ…œ์ด ์ธ๊ฐ„์˜ ์ง€๋Šฅ์„ ๋ชจ๋ฐฉํ•˜์—ฌ ํ•™์Šต, ์ถ”๋ก , "
 26                  "๋ฌธ์ œ ํ•ด๊ฒฐ ๋“ฑ์˜ ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ธฐ์ˆ ์ž…๋‹ˆ๋‹ค. ๋จธ์‹ ๋Ÿฌ๋‹, ๋”ฅ๋Ÿฌ๋‹, "
 27                  "์ž์—ฐ์–ด ์ฒ˜๋ฆฌ ๋“ฑ ๋‹ค์–‘ํ•œ ๋ถ„์•ผ๋ฅผ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.",
 28        "rejected": "AI๋Š” ์ปดํ“จํ„ฐ๊ฐ€ ๋˜‘๋˜‘ํ•ด์ง€๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค."
 29    },
 30    {
 31        "prompt": "ํŒŒ์ด์ฌ์˜ ์žฅ์ ์€?",
 32        "chosen": "ํŒŒ์ด์ฌ์˜ ์ฃผ์š” ์žฅ์ ์€ 1) ์ฝ๊ธฐ ์‰ฌ์šด ๋ฌธ๋ฒ•, 2) ํ’๋ถ€ํ•œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ, "
 33                  "3) ๋‹ค์–‘ํ•œ ๋ถ„์•ผ ์ ์šฉ ๊ฐ€๋Šฅ, 4) ํ™œ๋ฐœํ•œ ์ปค๋ฎค๋‹ˆํ‹ฐ์ž…๋‹ˆ๋‹ค.",
 34        "rejected": "ํŒŒ์ด์ฌ์€ ์ข‹์€ ์–ธ์–ด์ž…๋‹ˆ๋‹ค."
 35    },
 36    {
 37        "prompt": "์šด๋™์˜ ํšจ๊ณผ๋Š”?",
 38        "chosen": "๊ทœ์น™์ ์ธ ์šด๋™์€ ์‹ฌํ˜ˆ๊ด€ ๊ฑด๊ฐ• ๊ฐœ์„ , ์ฒด์ค‘ ๊ด€๋ฆฌ, ๊ทผ๋ ฅ ๊ฐ•ํ™”, "
 39                  "์ •์‹  ๊ฑด๊ฐ• ํ–ฅ์ƒ, ์ˆ˜๋ฉด ์งˆ ๊ฐœ์„  ๋“ฑ ๋‹ค์–‘ํ•œ ํšจ๊ณผ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.",
 40        "rejected": "์šด๋™ํ•˜๋ฉด ๊ฑด๊ฐ•ํ•ด์ง‘๋‹ˆ๋‹ค."
 41    }
 42]
 43
 44print("์„ ํ˜ธ๋„ ๋ฐ์ดํ„ฐ ์˜ˆ์‹œ:")
 45for i, data in enumerate(preference_data):
 46    print(f"\n{i+1}. ํ”„๋กฌํ”„ํŠธ: {data['prompt']}")
 47    print(f"   ์„ ํ˜ธ ์‘๋‹ต: {data['chosen'][:50]}...")
 48    print(f"   ๋น„์„ ํ˜ธ ์‘๋‹ต: {data['rejected']}")
 49
 50
 51# ============================================
 52# 2. ๊ฐ„๋‹จํ•œ Reward Model ์‹œ๋ฎฌ๋ ˆ์ด์…˜
 53# ============================================
 54print("\n[2] ๊ฐ„๋‹จํ•œ Reward Model")
 55print("-" * 40)
 56
 57class SimpleRewardModel:
 58    """๊ฐ„๋‹จํ•œ ๊ทœ์น™ ๊ธฐ๋ฐ˜ Reward Model (์‹œ๋ฎฌ๋ ˆ์ด์…˜์šฉ)"""
 59
 60    def __init__(self):
 61        self.positive_factors = {
 62            "length": 0.3,        # ์ ์ ˆํ•œ ๊ธธ์ด
 63            "detail": 0.3,        # ์ƒ์„ธํ•จ
 64            "structure": 0.2,     # ๊ตฌ์กฐํ™”
 65            "politeness": 0.2     # ์ •์ค‘ํ•จ
 66        }
 67
 68    def compute_reward(self, prompt, response):
 69        """์‘๋‹ต์— ๋Œ€ํ•œ ๋ณด์ƒ ์ ์ˆ˜ ๊ณ„์‚ฐ"""
 70        score = 0.0
 71
 72        # 1. ๊ธธ์ด ์ ์ˆ˜ (50-300์ž ์ตœ์ )
 73        length = len(response)
 74        if 50 <= length <= 300:
 75            score += self.positive_factors["length"]
 76        elif length > 300:
 77            score += self.positive_factors["length"] * 0.5
 78
 79        # 2. ์ƒ์„ธํ•จ (์ˆซ์ž, ์˜ˆ์‹œ ํฌํ•จ)
 80        if any(c.isdigit() for c in response):
 81            score += self.positive_factors["detail"] * 0.5
 82        if "์˜ˆ๋ฅผ ๋“ค์–ด" in response or "์˜ˆ์‹œ" in response:
 83            score += self.positive_factors["detail"] * 0.5
 84
 85        # 3. ๊ตฌ์กฐํ™” (์‰ผํ‘œ, ๋งˆ์นจํ‘œ ์‚ฌ์šฉ)
 86        if response.count(',') >= 2:
 87            score += self.positive_factors["structure"]
 88
 89        # 4. ์ •์ค‘ํ•จ
 90        polite_words = ["์ž…๋‹ˆ๋‹ค", "์Šต๋‹ˆ๋‹ค", "๋ฉ๋‹ˆ๋‹ค"]
 91        if any(word in response for word in polite_words):
 92            score += self.positive_factors["politeness"]
 93
 94        return score
 95
 96# ํ…Œ์ŠคํŠธ
 97reward_model = SimpleRewardModel()
 98
 99print("Reward Model ํ…Œ์ŠคํŠธ:")
100for data in preference_data:
101    chosen_reward = reward_model.compute_reward(data["prompt"], data["chosen"])
102    rejected_reward = reward_model.compute_reward(data["prompt"], data["rejected"])
103    print(f"\nํ”„๋กฌํ”„ํŠธ: {data['prompt']}")
104    print(f"  ์„ ํ˜ธ ์‘๋‹ต ์ ์ˆ˜: {chosen_reward:.2f}")
105    print(f"  ๋น„์„ ํ˜ธ ์‘๋‹ต ์ ์ˆ˜: {rejected_reward:.2f}")
106    print(f"  ์ •๋ ฌ ์—ฌ๋ถ€: {'OK' if chosen_reward > rejected_reward else 'FAIL'}")
107
108
109# ============================================
110# 3. Bradley-Terry ๋ชจ๋ธ (DPO ๊ธฐ๋ฐ˜)
111# ============================================
112print("\n[3] Bradley-Terry ๋ชจ๋ธ (์„ ํ˜ธ๋„ ํ™•๋ฅ )")
113print("-" * 40)
114
def bradley_terry_probability(reward_chosen, reward_rejected, beta=1.0):
    """Preference probability under the Bradley-Terry model.

    P(chosen > rejected) = sigmoid(beta * (r_chosen - r_rejected))

    Args:
        reward_chosen: reward assigned to the preferred response.
        reward_rejected: reward assigned to the dispreferred response.
        beta: inverse-temperature scaling of the reward gap.

    Returns:
        Probability in [0, 1] that the chosen response wins.
    """
    z = beta * (reward_chosen - reward_rejected)
    # Numerically stable sigmoid: never exponentiate a large positive
    # argument, so extreme negative reward gaps cannot overflow np.exp
    # (the naive 1/(1+exp(-z)) form emits a RuntimeWarning there).
    if z >= 0:
        return 1 / (1 + np.exp(-z))
    e = np.exp(z)
    return e / (1 + e)
124
def dpo_loss(reward_chosen, reward_rejected, beta=0.1):
    """Simplified DPO loss.

    L = -log(sigmoid(beta * (r_chosen - r_rejected)))

    The small epsilon keeps the log finite when the win probability
    underflows to zero.
    """
    win_prob = bradley_terry_probability(reward_chosen, reward_rejected, beta)
    return -np.log(win_prob + 1e-10)
134
135# ํ…Œ์ŠคํŠธ
136print("Bradley-Terry ์„ ํ˜ธ ํ™•๋ฅ :")
137for r_c, r_r in [(0.8, 0.3), (0.5, 0.5), (0.3, 0.7)]:
138    prob = bradley_terry_probability(r_c, r_r, beta=2.0)
139    loss = dpo_loss(r_c, r_r, beta=2.0)
140    print(f"  r_chosen={r_c}, r_rejected={r_r} -> P(chosen)={prob:.4f}, Loss={loss:.4f}")
141
142
143# ============================================
144# 4. PPO ๊ฐœ๋… ์‹œ๋ฎฌ๋ ˆ์ด์…˜
145# ============================================
146print("\n[4] PPO ๊ฐœ๋… ์‹œ๋ฎฌ๋ ˆ์ด์…˜")
147print("-" * 40)
148
class SimplePPOSimulator:
    """Toy illustration of PPO's clipped-objective mechanics."""

    def __init__(self, clip_epsilon=0.2, kl_coef=0.1):
        self.clip_epsilon = clip_epsilon   # clipping half-width for the prob ratio
        self.kl_coef = kl_coef             # weight of the KL penalty term
        self.policy_history = []           # kept for API compatibility (unused here)

    def compute_ratio(self, new_prob, old_prob):
        """Probability ratio pi_new / pi_old (epsilon guards divide-by-zero)."""
        return new_prob / (old_prob + 1e-10)

    def clip_ratio(self, ratio):
        """Clamp the ratio into [1 - eps, 1 + eps] (PPO clipping)."""
        lo = 1 - self.clip_epsilon
        hi = 1 + self.clip_epsilon
        return np.clip(ratio, lo, hi)

    def compute_ppo_objective(self, ratio, advantage):
        """Pessimistic min of clipped and unclipped advantage-weighted ratios."""
        unclipped = ratio * advantage
        clipped = self.clip_ratio(ratio) * advantage
        return min(unclipped, clipped)  # conservative update

    def compute_kl_penalty(self, new_prob, old_prob):
        """Scaled pointwise KL-style term between new and old probabilities."""
        kl_term = new_prob * np.log(new_prob / (old_prob + 1e-10) + 1e-10)
        return self.kl_coef * kl_term
176
177# ํ…Œ์ŠคํŠธ
178ppo = SimplePPOSimulator()
179print("PPO ํด๋ฆฌํ•‘ ์˜ˆ์‹œ:")
180
181test_cases = [
182    (0.8, 0.5, 1.0),   # ํ™•๋ฅ  ์ฆ๊ฐ€, ์–‘์˜ ์–ด๋“œ๋ฐดํ‹ฐ์ง€
183    (0.3, 0.5, 1.0),   # ํ™•๋ฅ  ๊ฐ์†Œ, ์–‘์˜ ์–ด๋“œ๋ฐดํ‹ฐ์ง€
184    (0.8, 0.5, -1.0),  # ํ™•๋ฅ  ์ฆ๊ฐ€, ์Œ์˜ ์–ด๋“œ๋ฐดํ‹ฐ์ง€
185]
186
187for new_p, old_p, adv in test_cases:
188    ratio = ppo.compute_ratio(new_p, old_p)
189    clipped = ppo.clip_ratio(ratio)
190    obj = ppo.compute_ppo_objective(ratio, adv)
191    print(f"  new_p={new_p}, old_p={old_p}, adv={adv}")
192    print(f"    ratio={ratio:.2f}, clipped={clipped:.2f}, objective={obj:.2f}")
193
194
195# ============================================
196# 5. SFT ๋ฐ์ดํ„ฐ ํ˜•์‹
197# ============================================
198print("\n[5] SFT (Supervised Fine-Tuning) ๋ฐ์ดํ„ฐ")
199print("-" * 40)
200
# Alpaca-style instruction records: instruction + optional input -> output.
alpaca_data = [
    {
        "instruction": "다음 텍스트를 요약하세요.",
        "input": "인공지능은 컴퓨터 과학의 한 분야로, 인간의 학습능력, 추론능력, "
                 "지각능력, 자연언어 이해능력 등을 컴퓨터 프로그램으로 실현한 기술이다.",
        "output": "인공지능은 인간의 지적 능력을 컴퓨터로 구현한 기술입니다."
    },
    {
        "instruction": "다음 문장을 영어로 번역하세요.",
        "input": "안녕하세요, 오늘 날씨가 좋네요.",
        "output": "Hello, the weather is nice today."
    }
]

print("Alpaca 형식 예시:")
for example in alpaca_data:
    print(f"\n  Instruction: {example['instruction']}")
    print(f"  Input: {example['input'][:40]}...")
    print(f"  Output: {example['output']}")

# ChatML-style transcript with explicit role tags.
chatml_example = """
<|system|>
You are a helpful assistant.
<|user|>
What is the capital of Korea?
<|assistant|>
The capital of South Korea is Seoul.
"""

print(f"\nChatML 형식 예시:{chatml_example}")


# ============================================
# 6. Constitutional AI simulation
# ============================================
print("\n[6] Constitutional AI 시뮬레이션")
print("-" * 40)
240
class ConstitutionalAI:
    """Simulated Constitutional AI: critique a response against a written
    set of principles, then (optionally) revise it."""

    def __init__(self):
        # The "constitution": principles every response must satisfy.
        self.constitution = [
            "응답은 도움이 되어야 합니다.",
            "응답은 해로운 내용을 포함하지 않아야 합니다.",
            "응답은 정직하고 사실에 기반해야 합니다.",
            "차별적이거나 편견 있는 내용을 포함하지 않아야 합니다."
        ]

    def check_principles(self, response):
        """Return a list of principle violations found by simple rules."""
        found = []

        # Flag possibly harmful keywords.
        harmful_words = ["폭력", "위험한", "불법"]
        if any(word in response for word in harmful_words):
            found.append("해로운 내용 포함 가능")

        # Very short answers are deemed unhelpful.
        if len(response) < 20:
            found.append("충분히 도움이 되지 않음")

        return found

    def critique(self, prompt, response):
        """Build a critique report; returns (report_text, violations)."""
        violations = self.check_principles(response)

        parts = [f"프롬프트: {prompt}\n응답: {response}\n\n원칙 검토:\n"]
        for num, principle in enumerate(self.constitution, 1):
            parts.append(f"  {num}. {principle}\n")

        if violations:
            parts.append("\n위반 사항:\n")
            for violation in violations:
                parts.append(f"  - {violation}\n")
        else:
            parts.append("\n모든 원칙 준수")

        return "".join(parts), violations

    def revise(self, response, violations):
        """Revise the response (simulation): append a follow-up offer when
        the answer was flagged as unhelpful."""
        revised = response
        if "충분히 도움이 되지 않음" in violations:
            revised = response + " 추가적인 설명이 필요하시면 말씀해 주세요."
        return revised
290
291
292# ํ…Œ์ŠคํŠธ
293cai = ConstitutionalAI()
294
295test_responses = [
296    ("ํŒŒ์ด์ฌ ๋ฐฐ์šฐ๋Š” ๋ฐฉ๋ฒ•?", "์ฑ…์„ ์ฝ์œผ์„ธ์š”."),
297    ("์šด๋™์˜ ํšจ๊ณผ?", "์šด๋™์€ ๊ฑด๊ฐ•์— ๋งค์šฐ ์ข‹์Šต๋‹ˆ๋‹ค. ์‹ฌํ˜ˆ๊ด€ ๊ธฐ๋Šฅ ๊ฐœ์„ , ์ฒด์ค‘ ๊ด€๋ฆฌ, ์ •์‹  ๊ฑด๊ฐ• ํ–ฅ์ƒ ๋“ฑ ๋‹ค์–‘ํ•œ ์ด์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค."),
298]
299
300print("Constitutional AI ๊ฒ€ํ† :")
301for prompt, response in test_responses:
302    critique, violations = cai.critique(prompt, response)
303    print(f"\n{'-'*30}")
304    print(critique)
305    if violations:
306        revised = cai.revise(response, violations)
307        print(f"์ˆ˜์ •๋œ ์‘๋‹ต: {revised}")
308
309
310# ============================================
311# 7. TRL ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์‚ฌ์šฉ๋ฒ• (์ฝ”๋“œ๋งŒ)
312# ============================================
313print("\n[7] TRL ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์ฝ”๋“œ ์˜ˆ์‹œ")
314print("-" * 40)
315
# Reference-only TRL usage listing (stored as a string and printed, never
# executed): SFT, DPO, and PPO training setups. NOTE(review): listing
# targets the classic TRL API (e.g. SFTTrainer max_seq_length,
# PPOTrainer.step) — confirm against the installed TRL version.
trl_code = '''
# SFT (Supervised Fine-Tuning)
from trl import SFTTrainer
from transformers import TrainingArguments

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    formatting_func=format_instruction,
    max_seq_length=1024,
    args=TrainingArguments(
        output_dir="./sft_model",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        learning_rate=2e-5,
    ),
)
trainer.train()

# DPO (Direct Preference Optimization)
from trl import DPOTrainer, DPOConfig

dpo_config = DPOConfig(
    beta=0.1,  # 온도 파라미터
    loss_type="sigmoid",
    max_length=512,
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=dpo_config,
    train_dataset=preference_dataset,  # prompt, chosen, rejected
    tokenizer=tokenizer,
)
trainer.train()

# PPO (Proximal Policy Optimization)
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

ppo_config = PPOConfig(
    learning_rate=1.41e-5,
    batch_size=16,
    ppo_epochs=4,
    target_kl=0.1,
)

model = AutoModelForCausalLMWithValueHead.from_pretrained("./sft_model")

ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)

# 학습 루프
for batch in dataloader:
    query_tensors = tokenize(batch["prompt"])
    response_tensors = ppo_trainer.generate(query_tensors)
    rewards = reward_model(query_tensors, response_tensors)
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
'''
print(trl_code)


# ============================================
# 8. Reward Model training (code listing only)
# ============================================
print("\n[8] Reward Model 학습 코드")
print("-" * 40)
387
# Reference-only listing (stored as a string and printed, never executed):
# training a scalar-output reward model with TRL's RewardTrainer, plus a
# helper to score a prompt/response pair at inference time.
reward_code = '''
from transformers import AutoModelForSequenceClassification, TrainingArguments
from trl import RewardTrainer

# Reward Model (분류 헤드 추가)
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    num_labels=1  # 스칼라 출력
)

# 학습
training_args = TrainingArguments(
    output_dir="./reward_model",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    learning_rate=1e-5,
)

trainer = RewardTrainer(
    model=reward_model,
    args=training_args,
    train_dataset=preference_dataset,
    tokenizer=tokenizer,
)
trainer.train()

# 보상 점수 계산
def get_reward(prompt, response):
    text = f"### Prompt: {prompt}\\n### Response: {response}"
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        reward = reward_model(**inputs).logits.squeeze().item()
    return reward
'''
print(reward_code)


# ============================================
# Wrap-up
# ============================================
print("\n" + "=" * 60)
print("RLHF 정리")
print("=" * 60)
431
# Closing summary of the RLHF pipeline (runtime text, printed as-is).
summary = """
RLHF 파이프라인:

1. SFT (Supervised Fine-Tuning)
   - 고품질 데이터로 기본 능력 학습
   - 형식: instruction, input, output

2. Reward Model 학습
   - 선호도 데이터로 보상 함수 학습
   - 형식: prompt, chosen, rejected

3. PPO (강화학습)
   - Reward Model로 정책 최적화
   - KL 페널티로 기준 모델과의 거리 제한

4. DPO (Direct Preference Optimization)
   - Reward Model 없이 직접 선호도 학습
   - L = -log(sigmoid(β * (log π(y_w) - log π(y_l))))

5. Constitutional AI
   - 원칙 기반 자기 비평 및 수정
   - 안전성 향상

정렬 방법 선택:
- 간단한 정렬: DPO (추천)
- 복잡한 정렬: RLHF (PPO)
- 안전성 중요: Constitutional AI
"""
print(summary)