1"""
213. ๋ชจ๋ธ ์์ํ (Model Quantization) ์์
3
4INT8/INT4 ์์ํ, bitsandbytes, GPTQ, AWQ ์ค์ต
5"""
6
7import numpy as np
8
9print("=" * 60)
10print("๋ชจ๋ธ ์์ํ (Model Quantization)")
11print("=" * 60)
12
13
14# ============================================
15# 1. ๊ธฐ๋ณธ ์์ํ ์ดํด
16# ============================================
17print("\n[1] ๊ธฐ๋ณธ ์์ํ ๊ฐ๋
")
18print("-" * 40)
19
20def quantize_symmetric(tensor, bits=8):
21 """๋์นญ ์์ํ (Symmetric Quantization)"""
22 qmin = -(2 ** (bits - 1))
23 qmax = 2 ** (bits - 1) - 1
24
25 # ์ค์ผ์ผ ๊ณ์ฐ
26 abs_max = np.abs(tensor).max()
27 scale = abs_max / qmax if abs_max != 0 else 1.0
28
29 # ์์ํ
30 quantized = np.round(tensor / scale).astype(np.int8)
31 quantized = np.clip(quantized, qmin, qmax)
32
33 return quantized, scale
34
35def dequantize(quantized, scale):
36 """์ญ์์ํ"""
37 return quantized.astype(np.float32) * scale
38
39
# Test
original = np.array([0.5, -1.2, 0.3, 2.1, -0.8, 0.0], dtype=np.float32)
print(f"Original tensor: {original}")

quantized, scale = quantize_symmetric(original, bits=8)
print(f"Quantized (INT8): {quantized}")
print(f"Scale: {scale:.6f}")

recovered = dequantize(quantized, scale)
print(f"Recovered: {recovered}")

error = np.abs(original - recovered).mean()
print(f"Mean quantization error: {error:.6f}")


# ============================================
# 2. Asymmetric Quantization
# ============================================
print("\n[2] Asymmetric quantization")
print("-" * 40)

def quantize_asymmetric(tensor, bits=8):
    """Asymmetric quantization: map [min, max] onto the full unsigned range."""
    qmin = 0
    qmax = 2 ** bits - 1

    min_val = tensor.min()
    max_val = tensor.max()

    scale = (max_val - min_val) / (qmax - qmin) if max_val != min_val else 1.0
    zero_point = round(-min_val / scale) if scale != 0 else 0

    # Clip before casting so out-of-range values cannot wrap around in uint8
    quantized = np.clip(np.round(tensor / scale + zero_point), qmin, qmax).astype(np.uint8)

    return quantized, scale, zero_point

def dequantize_asymmetric(quantized, scale, zero_point):
    """Asymmetric dequantization."""
    return (quantized.astype(np.float32) - zero_point) * scale


# Test
asym_quantized, asym_scale, zero_point = quantize_asymmetric(original, bits=8)
print(f"Asymmetric quantized (UINT8): {asym_quantized}")
print(f"Scale: {asym_scale:.6f}, zero point: {zero_point}")

asym_recovered = dequantize_asymmetric(asym_quantized, asym_scale, zero_point)
print(f"Recovered: {asym_recovered}")


# ============================================
# 3. Group Quantization
# ============================================
print("\n[3] Group quantization")
print("-" * 40)

def group_quantize(tensor, group_size=4, bits=4):
    """Group-wise quantization: one scale per group improves accuracy."""
    flat = tensor.flatten()
    # Pad so the length is divisible by group_size
    pad_size = (group_size - len(flat) % group_size) % group_size
    if pad_size > 0:
        flat = np.pad(flat, (0, pad_size))

    groups = flat.reshape(-1, group_size)
    quantized_groups = []
    scales = []

    qmax = 2 ** (bits - 1) - 1
    qmin = -(2 ** (bits - 1))

    # Each group gets its own scale from its own absolute maximum
    for group in groups:
        abs_max = np.abs(group).max()
        scale = abs_max / qmax if abs_max != 0 else 1.0
        q = np.clip(np.round(group / scale), qmin, qmax).astype(np.int8)
        quantized_groups.append(q)
        scales.append(scale)

    return np.array(quantized_groups), np.array(scales)

def group_dequantize(quantized_groups, scales):
    """Group-wise dequantization (any padding added above is kept)."""
    recovered = []
    for q, s in zip(quantized_groups, scales):
        recovered.append(q.astype(np.float32) * s)
    return np.concatenate(recovered)


# Test (16 elements divide evenly into groups of 4, so no padding here)
larger_tensor = np.random.randn(16).astype(np.float32)
print(f"Original (16 values): {larger_tensor[:8]}...")

g_quantized, g_scales = group_quantize(larger_tensor, group_size=4, bits=4)
print(f"Number of groups: {len(g_scales)}, group size: 4")
print(f"Scales: {g_scales}")

g_recovered = group_dequantize(g_quantized, g_scales)
g_error = np.abs(larger_tensor - g_recovered).mean()
print(f"Group quantization mean error: {g_error:.6f}")


# ============================================
# 4. Bit-Precision Comparison
# ============================================
print("\n[4] Bit-precision comparison")
print("-" * 40)

def compare_bit_precision(tensor):
    """Quantize at several bit widths and compare reconstruction error."""
    results = {}

    for bits in [8, 4, 2]:
        q, s = quantize_symmetric(tensor, bits=bits)
        r = dequantize(q, s)
        error = np.abs(tensor - r).mean()
        results[f"INT{bits}"] = {
            "error": error,
            "range": (-(2 ** (bits - 1)), 2 ** (bits - 1) - 1),
        }

    return results

comparison = compare_bit_precision(original)
print("Quantization comparison by bit width:")
for name, result in comparison.items():
    print(f"  {name}: error={result['error']:.6f}, range={result['range']}")
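
# The same comparison expressed as signal-to-quantization-noise ratio (SQNR,
# in dB); higher is better, and each extra bit is worth roughly 6 dB for
# well-behaved value distributions. A minimal sketch.
signal_power = np.mean(original.astype(np.float64) ** 2)
for bits in [8, 4, 2]:
    q, s = quantize_symmetric(original, bits=bits)
    noise_power = np.mean((original - dequantize(q, s)) ** 2)
    print(f"  INT{bits} SQNR: {10 * np.log10(signal_power / noise_power):.1f} dB")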


# ============================================
# 5. bitsandbytes Example (code only)
# ============================================
print("\n[5] bitsandbytes usage (code example)")
print("-" * 40)

bnb_code = '''
# bitsandbytes 8-bit quantization
# (recent transformers versions prefer quantization_config=BitsAndBytesConfig(load_in_8bit=True))
from transformers import AutoModelForCausalLM

model_8bit = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    load_in_8bit=True,
    device_map="auto"
)

# bitsandbytes 4-bit quantization (NF4)
from transformers import BitsAndBytesConfig
import torch

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # Normal Float 4
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True          # double quantization (quantize the scales too)
)

model_4bit = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto"
)

print(f"4-bit model memory: {model_4bit.get_memory_footprint() / 1e9:.2f} GB")
'''
print(bnb_code)
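
# A runnable numpy sketch of the idea behind NF4: instead of 16 uniformly
# spaced INT4 levels, place the levels at quantiles of a standard normal,
# which matches how trained weights tend to be distributed. The level values
# below follow the QLoRA paper; treat them as illustrative approximations,
# not the exact constants bitsandbytes uses internally.
nf4_levels = np.array([-1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848,
                       -0.0911, 0.0, 0.0796, 0.1609, 0.2461, 0.3379,
                       0.4407, 0.5626, 0.7230, 1.0], dtype=np.float32)

weights = np.random.randn(1024).astype(np.float32)
absmax = np.abs(weights).max()
idx = np.abs((weights / absmax)[:, None] - nf4_levels[None, :]).argmin(axis=1)
nf4_recovered = nf4_levels[idx] * absmax          # nearest-level round trip
u4_q, u4_s = quantize_symmetric(weights, bits=4)  # uniform INT4 baseline
print(f"Uniform INT4 error: {np.abs(weights - dequantize(u4_q, u4_s)).mean():.6f}")
print(f"NF4-style error:    {np.abs(weights - nf4_recovered).mean():.6f}")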


# ============================================
# 6. GPTQ Example (code only)
# ============================================
print("\n[6] GPTQ quantization (code example)")
print("-" * 40)

gptq_code = '''
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# GPTQ configuration; quantization needs calibration data, either a built-in
# dataset name such as "c4" or a list of calibration strings
gptq_config = GPTQConfig(
    bits=4,
    group_size=128,
    desc_act=True,
    dataset="c4",
    tokenizer=tokenizer
)

# Quantize while loading
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=gptq_config,
    device_map="auto"
)

model.save_pretrained("./llama-2-7b-gptq-4bit")

# Load an already-quantized model
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-GPTQ",
    device_map="auto"
)
'''
print(gptq_code)


# ============================================
# 7. AWQ Example (code only)
# ============================================
print("\n[7] AWQ quantization (code example)")
print("-" * 40)

awq_code = '''
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Load model and tokenizer
model = AutoAWQForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# AWQ quantization config
quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM"
}

# Quantize and save
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized("./llama-2-7b-awq")

# Inference with the quantized model
model = AutoAWQForCausalLM.from_quantized(
    "./llama-2-7b-awq",
    fuse_layers=True  # fuse layers for faster inference
)
'''
print(awq_code)


# ============================================
# 8. QLoRA Example (code only)
# ============================================
print("\n[8] QLoRA fine-tuning (code example)")
print("-" * 40)

qlora_code = '''
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch

# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

# Load the base model in 4-bit
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto"
)

# Prepare for k-bit training (upcasts norms, enables input gradients)
model = prepare_model_for_kbit_training(model)

# LoRA config
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA adapters on top of the frozen 4-bit base
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Output: trainable params are well under 1% of the total
'''
print(qlora_code)
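
# A rough back-of-the-envelope for why QLoRA fits on a single GPU: the frozen
# base is stored in 4-bit and only the LoRA adapters train. Dims assume
# Llama-2-7B (32 layers, hidden size 4096) with the r=16, four-projection
# config above; optimizer states and activations are ignored.
base_gb = 7e9 * 4 / 8 / (1024 ** 3)       # 7B params at 4 bits each
lora_params = 32 * 4 * (2 * 4096 * 16)    # layers x modules x (A and B matrices)
lora_gb = lora_params * 2 / (1024 ** 3)   # adapters kept in bf16
print(f"\nFrozen 4-bit base: ~{base_gb:.2f} GB, trainable LoRA: ~{lora_gb:.3f} GB "
      f"({lora_params / 7e9 * 100:.2f}% of params)")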


# ============================================
# 9. Quantization Memory-Savings Simulation
# ============================================
print("\n[9] Quantization memory-savings simulation")
print("-" * 40)

def estimate_model_size(params_billions, bits):
    """Estimate raw weight storage in GB (ignores scale/zero-point overhead)."""
    bytes_per_param = bits / 8
    size_gb = params_billions * 1e9 * bytes_per_param / (1024 ** 3)
    return size_gb

model_sizes = {
    "7B": 7,
    "13B": 13,
    "70B": 70,
}

precisions = {
    "FP32": 32,
    "FP16": 16,
    "INT8": 8,
    "INT4": 4,
}

print("Estimated model size (GB):")
print("-" * 60)
header = "Model\t" + "\t".join(precisions.keys())
print(header)
print("-" * 60)

for model_name, params in model_sizes.items():
    sizes = [f"{estimate_model_size(params, bits):.1f}" for bits in precisions.values()]
    print(f"{model_name}\t" + "\t".join(sizes))
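
# A minimal sketch of the overhead that group-wise scales add on top of raw
# INT4 weights; assumes one FP16 scale per group of 128 weights, a common
# GPTQ/AWQ configuration (double quantization would shrink this further).
group_size = 128
extra_bits_per_param = 16 / group_size    # one FP16 scale amortized per weight
effective_bits = 4 + extra_bits_per_param
print(f"\nINT4 + FP16 scale per {group_size} weights: "
      f"{effective_bits:.3f} effective bits/param "
      f"(+{extra_bits_per_param / 4 * 100:.1f}% over pure INT4)")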


# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("Quantization Summary")
print("=" * 60)

summary = """
Key quantization concepts:

1. Symmetric quantization:
   - scale = max(|x|) / (2^(bits-1) - 1)
   - x_q = round(x / scale)
   - x' = x_q * scale

2. Asymmetric quantization:
   - scale = (max - min) / (2^bits - 1)
   - zero_point = round(-min / scale)
   - x_q = round(x / scale + zero_point)

3. Method comparison:
   - bitsandbytes: easiest to apply, quantizes on the fly at load time
   - GPTQ: high quality, needs calibration data
   - AWQ: fast quantization, activation-aware weight scaling
   - QLoRA: 4-bit quantization + LoRA fine-tuning

4. Choosing a method:
   - Prototyping: bitsandbytes (load_in_8bit)
   - Tight memory budget: bitsandbytes (load_in_4bit)
   - Production serving: GPTQ or AWQ
   - Fine-tuning: QLoRA
"""
print(summary)