"""RLHF and LLM alignment (Alignment) examples.

Hands-on walkthrough of the concepts behind Reward Models, PPO, DPO,
and Constitutional AI, using small rule-based simulations.
"""

import numpy as np
import random  # NOTE(review): appears unused in this file as shown — confirm before removing

print("=" * 60)
print("RLHF์ LLM ์ ๋ ฌ (Alignment)")
print("=" * 60)
13

# ============================================
# 1. Understanding preference data
# ============================================
print("\n[1] ์ ํธ๋ ๋ฐ์ดํฐ ํ์")
print("-" * 40)

# Example preference dataset: each record pairs a prompt with a preferred
# ("chosen") and a dispreferred ("rejected") response — the format used to
# train reward models and DPO.
preference_data = [
    {
        "prompt": "์ธ๊ณต์ง๋ฅ์ด๋ ๋ฌด์์ธ๊ฐ์?",
        "chosen": "์ธ๊ณต์ง๋ฅ(AI)์ ์ปดํจํฐ ์์คํ์ด ์ธ๊ฐ์ ์ง๋ฅ์ ๋ชจ๋ฐฉํ์ฌ ํ์ต, ์ถ๋ก , "
                  "๋ฌธ์  ํด๊ฒฐ ๋ฑ์ ์์์ ์ํํ๋ ๊ธฐ์ ์๋๋ค. ๋จธ์ ๋ฌ๋, ๋ฅ๋ฌ๋, "
                  "์์ฐ์ด ์ฒ๋ฆฌ ๋ฑ ๋ค์ํ ๋ถ์ผ๋ฅผ ํฌํจํฉ๋๋ค.",
        "rejected": "AI๋ ์ปดํจํฐ๊ฐ ๋๋ํด์ง๋ ๊ฒ์๋๋ค."
    },
    {
        "prompt": "ํ์ด์ฌ์ ์ฅ์ ์?",
        "chosen": "ํ์ด์ฌ์ ์ฃผ์ ์ฅ์ ์ 1) ์ฝ๊ธฐ ์ฌ์ด ๋ฌธ๋ฒ, 2) ํ๋ถํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ, "
                  "3) ๋ค์ํ ๋ถ์ผ ์ ์ฉ ๊ฐ๋ฅ, 4) ํ๋ฐํ ์ปค๋ฎค๋ํฐ์๋๋ค.",
        "rejected": "ํ์ด์ฌ์ ์ข์ ์ธ์ด์๋๋ค."
    },
    {
        "prompt": "์ด๋์ ํจ๊ณผ๋?",
        "chosen": "๊ท์น์ ์ธ ์ด๋์ ์ฌํ๊ด ๊ฑด๊ฐ ๊ฐ์ , ์ฒด์ค ๊ด๋ฆฌ, ๊ทผ๋ ฅ ๊ฐํ, "
                  "์ ์  ๊ฑด๊ฐ ํฅ์, ์๋ฉด ์ง ๊ฐ์  ๋ฑ ๋ค์ํ ํจ๊ณผ๊ฐ ์์ต๋๋ค.",
        "rejected": "์ด๋ํ๋ฉด ๊ฑด๊ฐํด์ง๋๋ค."
    }
]

print("์ ํธ๋ ๋ฐ์ดํฐ ์์:")
for i, data in enumerate(preference_data):
    print(f"\n{i+1}. ํ๋กฌํํธ: {data['prompt']}")
    print(f"   ์ ํธ ์๋ต: {data['chosen'][:50]}...")
    print(f"   ๋น์ ํธ ์๋ต: {data['rejected']}")
49
50
# ============================================
# 2. Simple reward-model simulation
# ============================================
print("\n[2] ๊ฐ๋จํ Reward Model")
print("-" * 40)

class SimpleRewardModel:
    """Toy rule-based reward model (simulation only).

    Scores a response with hand-written heuristics instead of a learned
    network, to illustrate the kind of scalar score a real reward model
    would output.
    """

    def __init__(self):
        # Heuristic weights; they sum to 1.0 so the maximum reward is 1.0.
        self.positive_factors = {
            "length": 0.3,      # appropriate length
            "detail": 0.3,      # level of detail
            "structure": 0.2,   # structured writing
            "politeness": 0.2   # polite tone
        }

    def compute_reward(self, prompt, response):
        """Return a heuristic reward score in [0.0, 1.0] for ``response``.

        ``prompt`` is currently unused by the heuristics; it is kept to
        mirror the (prompt, response) signature of a real reward model.
        """
        score = 0.0

        # 1. Length score (50-300 characters considered optimal; longer
        #    responses get half credit, shorter ones get none).
        length = len(response)
        if 50 <= length <= 300:
            score += self.positive_factors["length"]
        elif length > 300:
            score += self.positive_factors["length"] * 0.5

        # 2. Detail: digits (enumerations) and example phrases each earn half.
        if any(c.isdigit() for c in response):
            score += self.positive_factors["detail"] * 0.5
        if "์๋ฅผ ๋ค์ด" in response or "์์" in response:
            score += self.positive_factors["detail"] * 0.5

        # 3. Structure: at least two commas suggests an enumerated answer.
        if response.count(',') >= 2:
            score += self.positive_factors["structure"]

        # 4. Politeness: formal Korean sentence endings.
        polite_words = ["์๋๋ค", "์ต๋๋ค", "๋ฉ๋๋ค"]
        if any(word in response for word in polite_words):
            score += self.positive_factors["politeness"]

        return score
95
# Quick check: the reward model should rank every "chosen" response above
# its "rejected" counterpart.
reward_model = SimpleRewardModel()

print("Reward Model ํ์คํธ:")
for data in preference_data:
    chosen_reward = reward_model.compute_reward(data["prompt"], data["chosen"])
    rejected_reward = reward_model.compute_reward(data["prompt"], data["rejected"])
    print(f"\nํ๋กฌํํธ: {data['prompt']}")
    print(f"  ์ ํธ ์๋ต ์ ์: {chosen_reward:.2f}")
    print(f"  ๋น์ ํธ ์๋ต ์ ์: {rejected_reward:.2f}")
    print(f"  ์ ๋ ฌ ์ฌ๋ถ: {'OK' if chosen_reward > rejected_reward else 'FAIL'}")
107
108
# ============================================
# 3. Bradley-Terry model (the basis of DPO)
# ============================================
print("\n[3] Bradley-Terry ๋ชจ๋ธ (์ ํธ๋ ํ๋ฅ )")
print("-" * 40)
114
def bradley_terry_probability(reward_chosen, reward_rejected, beta=1.0):
    """Preference probability under the Bradley-Terry model.

    P(chosen > rejected) = sigmoid(beta * (r_chosen - r_rejected))

    Args:
        reward_chosen: Reward of the preferred response (scalar or array).
        reward_rejected: Reward of the dispreferred response.
        beta: Temperature; larger values sharpen the probability.

    Returns:
        Probability in (0, 1) that the chosen response wins.
    """
    diff = reward_chosen - reward_rejected
    # sigmoid(x) == 0.5 * (1 + tanh(x / 2)) exactly. The tanh form never
    # overflows, unlike the naive 1 / (1 + exp(-x)), where np.exp(-x)
    # overflows (RuntimeWarning, inf) for large negative beta * diff.
    return 0.5 * (1.0 + np.tanh(0.5 * beta * diff))
124
def dpo_loss(reward_chosen, reward_rejected, beta=0.1):
    """Simplified DPO loss.

    L = -log(sigmoid(beta * (r_chosen - r_rejected)))

    Computed as log(1 + exp(-beta * diff)) via ``np.logaddexp``, which is
    exact and numerically stable for any magnitude of ``diff``. The
    previous implementation added a 1e-10 fudge term inside the log,
    which both biased the loss and capped it near 23 for strongly wrong
    preferences.

    Args:
        reward_chosen: Reward of the preferred response.
        reward_rejected: Reward of the dispreferred response.
        beta: Temperature parameter of the implicit reward.

    Returns:
        Non-negative loss value (0 when chosen is infinitely preferred).
    """
    diff = reward_chosen - reward_rejected
    # -log(sigmoid(z)) == log(1 + exp(-z)) == logaddexp(0, -z)
    return np.logaddexp(0.0, -beta * diff)
134
# Demo: sweep a few (chosen, rejected) reward pairs through the model.
print("Bradley-Terry ์ ํธ ํ๋ฅ :")
for r_c, r_r in [(0.8, 0.3), (0.5, 0.5), (0.3, 0.7)]:
    prob = bradley_terry_probability(r_c, r_r, beta=2.0)
    loss = dpo_loss(r_c, r_r, beta=2.0)
    print(f"  r_chosen={r_c}, r_rejected={r_r} -> P(chosen)={prob:.4f}, Loss={loss:.4f}")
141
142
# ============================================
# 4. PPO concept simulation
# ============================================
print("\n[4] PPO ๊ฐ๋์๋ฎฌ๋ ์ด์")
print("-" * 40)
148
class SimplePPOSimulator:
    """Illustrates the core PPO quantities: ratio, clipping, KL penalty.

    Pedagogical only — it operates on plain probabilities rather than on
    a real policy network.
    """

    def __init__(self, clip_epsilon=0.2, kl_coef=0.1):
        self.clip_epsilon = clip_epsilon  # PPO clip range epsilon
        self.kl_coef = kl_coef            # weight of the KL penalty term
        self.policy_history = []          # reserved for logging updates

    def compute_ratio(self, new_prob, old_prob):
        """Importance-sampling ratio pi_new / pi_old (epsilon-guarded)."""
        denom = old_prob + 1e-10  # guard against division by zero
        return new_prob / denom

    def clip_ratio(self, ratio):
        """Clamp the ratio into [1 - eps, 1 + eps] (the PPO clip)."""
        lower = 1 - self.clip_epsilon
        upper = 1 + self.clip_epsilon
        return np.clip(ratio, lower, upper)

    def compute_ppo_objective(self, ratio, advantage):
        """Clipped surrogate objective: min(r * A, clip(r) * A)."""
        unclipped = ratio * advantage
        clipped = self.clip_ratio(ratio) * advantage
        # Taking the minimum keeps the policy update conservative.
        return min(unclipped, clipped)

    def compute_kl_penalty(self, new_prob, old_prob):
        """KL-divergence penalty term, scaled by ``kl_coef``."""
        log_ratio = np.log(new_prob / (old_prob + 1e-10) + 1e-10)
        return self.kl_coef * (new_prob * log_ratio)
176
# Demo: show how clipping limits the effective update in each case.
ppo = SimplePPOSimulator()
print("PPO ํด๋ฆฌํ ์์:")

test_cases = [
    (0.8, 0.5, 1.0),   # probability up, positive advantage -> gain clipped
    (0.3, 0.5, 1.0),   # probability down, positive advantage
    (0.8, 0.5, -1.0),  # probability up, negative advantage -> penalized fully
]

for new_p, old_p, adv in test_cases:
    ratio = ppo.compute_ratio(new_p, old_p)
    clipped = ppo.clip_ratio(ratio)
    obj = ppo.compute_ppo_objective(ratio, adv)
    print(f"  new_p={new_p}, old_p={old_p}, adv={adv}")
    print(f"    ratio={ratio:.2f}, clipped={clipped:.2f}, objective={obj:.2f}")
193
194
# ============================================
# 5. SFT data formats
# ============================================
print("\n[5] SFT (Supervised Fine-Tuning) ๋ฐ์ดํฐ")
print("-" * 40)

# Alpaca format: instruction / input / output triples.
alpaca_data = [
    {
        "instruction": "๋ค์ ํ์คํธ๋ฅผ ์์ฝํ์ธ์.",
        "input": "์ธ๊ณต์ง๋ฅ์ ์ปดํจํฐ ๊ณผํ์ ํ ๋ถ์ผ๋ก, ์ธ๊ฐ์ ํ์ต๋ฅ๋ ฅ, ์ถ๋ก ๋ฅ๋ ฅ, "
                 "์ง๊ฐ๋ฅ๋ ฅ, ์์ฐ์ธ์ด ์ดํด๋ฅ๋ ฅ ๋ฑ์ ์ปดํจํฐ ํ๋ก๊ทธ๋จ์ผ๋ก ์คํํ ๊ธฐ์ ์ด๋ค.",
        "output": "์ธ๊ณต์ง๋ฅ์ ์ธ๊ฐ์ ์ง์  ๋ฅ๋ ฅ์ ์ปดํจํฐ๋ก ๊ตฌํํ ๊ธฐ์ ์๋๋ค."
    },
    {
        "instruction": "๋ค์ ๋ฌธ์ฅ์ ์์ด๋ก ๋ฒ์ญํ์ธ์.",
        "input": "์๋ํ์ธ์, ์ค๋ ๋ ์จ๊ฐ ์ข๋ค์.",
        "output": "Hello, the weather is nice today."
    }
]

print("Alpaca ํ์ ์์:")
for item in alpaca_data:
    print(f"\n  Instruction: {item['instruction']}")
    print(f"  Input: {item['input'][:40]}...")
    print(f"  Output: {item['output']}")

# ChatML format: role-tagged multi-turn transcript.
chatml_example = """
<|system|>
You are a helpful assistant.
<|user|>
What is the capital of Korea?
<|assistant|>
The capital of South Korea is Seoul.
"""

print(f"\nChatML ํ์ ์์:{chatml_example}")
233
234
# ============================================
# 6. Constitutional AI simulation
# ============================================
print("\n[6] Constitutional AI ์๋ฎฌ๋ ์ด์")
print("-" * 40)

class ConstitutionalAI:
    """Constitutional AI simulation.

    Checks a response against a small fixed list of principles (the
    "constitution"), produces a critique, and optionally revises the
    response. Real Constitutional AI uses an LLM for the critique and
    revision steps; here both are simple rules.
    """

    def __init__(self):
        # The constitution: principles every response should satisfy.
        self.constitution = [
            "์๋ต์ ๋์์ด ๋์ด์ผ ํฉ๋๋ค.",
            "์๋ต์ ํด๋ก์ด ๋ด์ฉ์ ํฌํจํ์ง ์์์ผ ํฉ๋๋ค.",
            "์๋ต์ ์ ์งํ๊ณ ์ฌ์ค์ ๊ธฐ๋ฐํด์ผ ํฉ๋๋ค.",
            "์ฐจ๋ณ์ ์ด๊ฑฐ๋ ํธ๊ฒฌ ์๋ ๋ด์ฉ์ ํฌํจํ์ง ์์์ผ ํฉ๋๋ค."
        ]

    def check_principles(self, response):
        """Return a list of principle violations found in ``response``
        (simple rule-based checks)."""
        violations = []

        # Harmful-keyword check.
        harmful_words = ["ํญ๋ ฅ", "์ํํ", "๋ถ๋ฒ"]
        if any(word in response for word in harmful_words):
            violations.append("ํด๋ก์ด ๋ด์ฉ ํฌํจ ๊ฐ๋ฅ")

        # Very short responses are treated as unhelpful.
        if len(response) < 20:
            violations.append("์ถฉ๋ถํ ๋์์ด ๋์ง ์์")

        return violations

    def critique(self, prompt, response):
        """Build a critique text; return ``(critique, violations)``."""
        violations = self.check_principles(response)

        critique = f"ํ๋กฌํํธ: {prompt}\n์๋ต: {response}\n\n์์น ๊ฒํ :\n"
        for i, principle in enumerate(self.constitution, 1):
            critique += f"  {i}. {principle}\n"

        if violations:
            critique += f"\n์๋ฐ ์ฌํญ:\n"
            for v in violations:
                critique += f"  - {v}\n"
        else:
            critique += "\n๋ชจ๋ ์์น ์ค์"

        return critique, violations

    def revise(self, response, violations):
        """Revise ``response`` to address ``violations`` (simulation:
        only the too-short case is handled)."""
        revised = response
        if "์ถฉ๋ถํ ๋์์ด ๋์ง ์์" in violations:
            revised = response + " ์ถ๊ฐ์ ์ธ ์ค๋ช์ด ํ์ํ์๋ฉด ๋ง์ํด ์ฃผ์ธ์."
        return revised
290
291
292# ํ
์คํธ
293cai = ConstitutionalAI()
294
295test_responses = [
296 ("ํ์ด์ฌ ๋ฐฐ์ฐ๋ ๋ฐฉ๋ฒ?", "์ฑ
์ ์ฝ์ผ์ธ์."),
297 ("์ด๋์ ํจ๊ณผ?", "์ด๋์ ๊ฑด๊ฐ์ ๋งค์ฐ ์ข์ต๋๋ค. ์ฌํ๊ด ๊ธฐ๋ฅ ๊ฐ์ , ์ฒด์ค ๊ด๋ฆฌ, ์ ์ ๊ฑด๊ฐ ํฅ์ ๋ฑ ๋ค์ํ ์ด์ ์ด ์์ต๋๋ค."),
298]
299
300print("Constitutional AI ๊ฒํ :")
301for prompt, response in test_responses:
302 critique, violations = cai.critique(prompt, response)
303 print(f"\n{'-'*30}")
304 print(critique)
305 if violations:
306 revised = cai.revise(response, violations)
307 print(f"์์ ๋ ์๋ต: {revised}")
308
309
# ============================================
# 7. TRL library usage (code listing only)
# ============================================
print("\n[7] TRL ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์ฝ๋ ์์")
print("-" * 40)

# Display-only listing (not executed here): the typical TRL workflow of
# SFT -> DPO -> PPO.
trl_code = '''
# SFT (Supervised Fine-Tuning)
from trl import SFTTrainer
from transformers import TrainingArguments

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    formatting_func=format_instruction,
    max_seq_length=1024,
    args=TrainingArguments(
        output_dir="./sft_model",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        learning_rate=2e-5,
    ),
)
trainer.train()

# DPO (Direct Preference Optimization)
from trl import DPOTrainer, DPOConfig

dpo_config = DPOConfig(
    beta=0.1, # ์จ๋ ํ๋ผ๋ฏธํฐ
    loss_type="sigmoid",
    max_length=512,
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=dpo_config,
    train_dataset=preference_dataset, # prompt, chosen, rejected
    tokenizer=tokenizer,
)
trainer.train()

# PPO (Proximal Policy Optimization)
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

ppo_config = PPOConfig(
    learning_rate=1.41e-5,
    batch_size=16,
    ppo_epochs=4,
    target_kl=0.1,
)

model = AutoModelForCausalLMWithValueHead.from_pretrained("./sft_model")

ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)

# ํ์ต ๋ฃจํ
for batch in dataloader:
    query_tensors = tokenize(batch["prompt"])
    response_tensors = ppo_trainer.generate(query_tensors)
    rewards = reward_model(query_tensors, response_tensors)
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
'''
print(trl_code)
380
381
# ============================================
# 8. Reward-model training (code listing only)
# ============================================
print("\n[8] Reward Model ํ์ต ์ฝ๋")
print("-" * 40)

# Display-only listing: training a scalar-output classifier as a reward
# model with TRL's RewardTrainer.
reward_code = '''
from transformers import AutoModelForSequenceClassification, TrainingArguments
from trl import RewardTrainer

# Reward Model (๋ถ๋ฅ ํค๋ ์ถ๊ฐ)
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    num_labels=1 # ์ค์นผ๋ผ ์ถ๋ ฅ
)

# ํ์ต
training_args = TrainingArguments(
    output_dir="./reward_model",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    learning_rate=1e-5,
)

trainer = RewardTrainer(
    model=reward_model,
    args=training_args,
    train_dataset=preference_dataset,
    tokenizer=tokenizer,
)
trainer.train()

# ๋ณด์ ์ ์ ๊ณ์ฐ
def get_reward(prompt, response):
    text = f"### Prompt: {prompt}\\n### Response: {response}"
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        reward = reward_model(**inputs).logits.squeeze().item()
    return reward
'''
print(reward_code)
423
424
# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("RLHF ์ ๋ฆฌ")
print("=" * 60)

# Recap shown to the reader (runtime text kept verbatim).
summary = """
RLHF ํ์ดํ๋ผ์ธ:

1. SFT (Supervised Fine-Tuning)
   - ๊ณ ํ์ง ๋ฐ์ดํฐ๋ก ๊ธฐ๋ณธ ๋ฅ๋ ฅ ํ์ต
   - ํ์: instruction, input, output

2. Reward Model ํ์ต
   - ์ ํธ๋ ๋ฐ์ดํฐ๋ก ๋ณด์ ํจ์ ํ์ต
   - ํ์: prompt, chosen, rejected

3. PPO (๊ฐํํ์ต)
   - Reward Model๋ก ์ ์ฑ์ต์ ํ
   - KL ํ๋ํฐ๋ก ๊ธฐ์ค ๋ชจ๋ธ๊ณผ์ ๊ฑฐ๋ฆฌ ์ ํ

4. DPO (Direct Preference Optimization)
   - Reward Model ์์ด ์ง์  ์ ํธ๋ ํ์ต
   - L = -log(sigmoid(ฮฒ * (log ฯ(y_w) - log ฯ(y_l))))

5. Constitutional AI
   - ์์น ๊ธฐ๋ฐ ์๊ธฐ ๋นํ ๋ฐ ์์ 
   - ์์ ์ฑ ํฅ์

์ ๋ ฌ ๋ฐฉ๋ฒ ์ ํ:
- ๊ฐ๋จํ ์ ๋ ฌ: DPO (์ถ์ฒ)
- ๋ณต์กํ ์ ๋ ฌ: RLHF (PPO)
- ์์ ์ฑ ์ค์: Constitutional AI
"""
print(summary)