# 13_quantization_example.py
  1"""
  213. ๋ชจ๋ธ ์–‘์žํ™” (Model Quantization) ์˜ˆ์ œ
  3
  4INT8/INT4 ์–‘์žํ™”, bitsandbytes, GPTQ, AWQ ์‹ค์Šต
  5"""
  6
  7import numpy as np
  8
  9print("=" * 60)
 10print("๋ชจ๋ธ ์–‘์žํ™” (Model Quantization)")
 11print("=" * 60)
 12
 13
 14# ============================================
 15# 1. ๊ธฐ๋ณธ ์–‘์žํ™” ์ดํ•ด
 16# ============================================
 17print("\n[1] ๊ธฐ๋ณธ ์–‘์žํ™” ๊ฐœ๋…")
 18print("-" * 40)
 19
 20def quantize_symmetric(tensor, bits=8):
 21    """๋Œ€์นญ ์–‘์žํ™” (Symmetric Quantization)"""
 22    qmin = -(2 ** (bits - 1))
 23    qmax = 2 ** (bits - 1) - 1
 24
 25    # ์Šค์ผ€์ผ ๊ณ„์‚ฐ
 26    abs_max = np.abs(tensor).max()
 27    scale = abs_max / qmax if abs_max != 0 else 1.0
 28
 29    # ์–‘์žํ™”
 30    quantized = np.round(tensor / scale).astype(np.int8)
 31    quantized = np.clip(quantized, qmin, qmax)
 32
 33    return quantized, scale
 34
 35def dequantize(quantized, scale):
 36    """์—ญ์–‘์žํ™”"""
 37    return quantized.astype(np.float32) * scale
 38
 39
 40# ํ…Œ์ŠคํŠธ
 41original = np.array([0.5, -1.2, 0.3, 2.1, -0.8, 0.0], dtype=np.float32)
 42print(f"์›๋ณธ ํ…์„œ: {original}")
 43
 44quantized, scale = quantize_symmetric(original, bits=8)
 45print(f"์–‘์žํ™”๋จ (INT8): {quantized}")
 46print(f"์Šค์ผ€์ผ: {scale:.6f}")
 47
 48recovered = dequantize(quantized, scale)
 49print(f"๋ณต์›๋จ: {recovered}")
 50
 51error = np.abs(original - recovered).mean()
 52print(f"ํ‰๊ท  ์–‘์žํ™” ์˜ค์ฐจ: {error:.6f}")
 53
 54
 55# ============================================
 56# 2. ๋น„๋Œ€์นญ ์–‘์žํ™”
 57# ============================================
 58print("\n[2] ๋น„๋Œ€์นญ ์–‘์žํ™”")
 59print("-" * 40)
 60
 61def quantize_asymmetric(tensor, bits=8):
 62    """๋น„๋Œ€์นญ ์–‘์žํ™” (Asymmetric Quantization)"""
 63    qmin = 0
 64    qmax = 2 ** bits - 1
 65
 66    min_val = tensor.min()
 67    max_val = tensor.max()
 68
 69    scale = (max_val - min_val) / (qmax - qmin) if max_val != min_val else 1.0
 70    zero_point = round(-min_val / scale) if scale != 0 else 0
 71
 72    quantized = np.round(tensor / scale + zero_point).astype(np.uint8)
 73    quantized = np.clip(quantized, qmin, qmax)
 74
 75    return quantized, scale, zero_point
 76
 77def dequantize_asymmetric(quantized, scale, zero_point):
 78    """๋น„๋Œ€์นญ ์—ญ์–‘์žํ™”"""
 79    return (quantized.astype(np.float32) - zero_point) * scale
 80
 81
 82# ํ…Œ์ŠคํŠธ
 83asym_quantized, asym_scale, zero_point = quantize_asymmetric(original, bits=8)
 84print(f"๋น„๋Œ€์นญ ์–‘์žํ™” (UINT8): {asym_quantized}")
 85print(f"์Šค์ผ€์ผ: {asym_scale:.6f}, Zero Point: {zero_point}")
 86
 87asym_recovered = dequantize_asymmetric(asym_quantized, asym_scale, zero_point)
 88print(f"๋ณต์›๋จ: {asym_recovered}")
 89
 90
 91# ============================================
 92# 3. ๊ทธ๋ฃน๋ณ„ ์–‘์žํ™”
 93# ============================================
 94print("\n[3] ๊ทธ๋ฃน๋ณ„ ์–‘์žํ™” (Group Quantization)")
 95print("-" * 40)
 96
 97def group_quantize(tensor, group_size=4, bits=4):
 98    """๊ทธ๋ฃน๋ณ„ ์–‘์žํ™” - ์ •ํ™•๋„ ํ–ฅ์ƒ"""
 99    flat = tensor.flatten()
100    pad_size = (group_size - len(flat) % group_size) % group_size
101    if pad_size > 0:
102        flat = np.pad(flat, (0, pad_size))
103
104    groups = flat.reshape(-1, group_size)
105    quantized_groups = []
106    scales = []
107
108    qmax = 2 ** (bits - 1) - 1
109    qmin = -(2 ** (bits - 1))
110
111    for group in groups:
112        abs_max = np.abs(group).max()
113        scale = abs_max / qmax if abs_max != 0 else 1.0
114        q = np.round(group / scale).astype(np.int8)
115        q = np.clip(q, qmin, qmax)
116        quantized_groups.append(q)
117        scales.append(scale)
118
119    return np.array(quantized_groups), np.array(scales)
120
def group_dequantize(quantized_groups, scales, length=None):
    """Invert `group_quantize`: per-group code * scale, flattened to 1-D.

    Args:
        quantized_groups: (num_groups, group_size) integer codes.
        scales: one scale per group.
        length: optional original (pre-padding) element count. When given,
            the zero padding that `group_quantize` appended is stripped —
            without it the result keeps the padded length (the original
            implementation always returned the padded length, which only
            worked when the input size was a multiple of the group size).

    Returns:
        1-D float array of recovered values.
    """
    codes = np.asarray(quantized_groups).astype(np.float32)
    recovered = (codes * np.asarray(scales)[:, None]).reshape(-1)
    if length is not None:
        recovered = recovered[:length]
    return recovered
127
128
129# ํ…Œ์ŠคํŠธ
130larger_tensor = np.random.randn(16).astype(np.float32)
131print(f"์›๋ณธ (16๊ฐœ): {larger_tensor[:8]}...")
132
133g_quantized, g_scales = group_quantize(larger_tensor, group_size=4, bits=4)
134print(f"๊ทธ๋ฃน ์ˆ˜: {len(g_scales)}, ๊ทธ๋ฃน ํฌ๊ธฐ: 4")
135print(f"์Šค์ผ€์ผ๋“ค: {g_scales}")
136
137g_recovered = group_dequantize(g_quantized, g_scales)
138g_error = np.abs(larger_tensor - g_recovered).mean()
139print(f"๊ทธ๋ฃน ์–‘์žํ™” ํ‰๊ท  ์˜ค์ฐจ: {g_error:.6f}")
140
141
142# ============================================
143# 4. ๋น„ํŠธ ์ •๋ฐ€๋„ ๋น„๊ต
144# ============================================
145print("\n[4] ๋น„ํŠธ ์ •๋ฐ€๋„ ๋น„๊ต")
146print("-" * 40)
147
def compare_bit_precision(tensor):
    """Compare symmetric-quantization reconstruction error across bit widths.

    Runs INT8/INT4/INT2 round trips on `tensor` and reports, per precision,
    the mean absolute error and the representable integer range.
    """
    results = {}
    for bits in (8, 4, 2):
        codes, scale = quantize_symmetric(tensor, bits=bits)
        restored = dequantize(codes, scale)
        results[f"INT{bits}"] = {
            "error": np.abs(tensor - restored).mean(),
            "range": (-(2 ** (bits - 1)), 2 ** (bits - 1) - 1),
        }
    return results
162
# Print the per-precision error/range table for the demo tensor.
comparison = compare_bit_precision(original)
print("비트별 양자화 비교:")
for name, result in comparison.items():
    print(f"  {name}: 오차={result['error']:.6f}, 범위={result['range']}")


# ============================================
# 5. bitsandbytes example (code only)
# ============================================
print("\n[5] bitsandbytes 사용법 (코드 예시)")
print("-" * 40)
174
# Illustrative-only snippet: requires `transformers`, `bitsandbytes`, model
# weights and (typically) a GPU, so it is printed rather than executed.
bnb_code = '''
# bitsandbytes 8비트 양자화
from transformers import AutoModelForCausalLM, AutoTokenizer

model_8bit = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    load_in_8bit=True,
    device_map="auto"
)

# bitsandbytes 4비트 양자화 (NF4)
from transformers import BitsAndBytesConfig
import torch

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",           # Normal Float 4
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True       # 이중 양자화
)

model_4bit = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto"
)

print(f"4bit 모델 메모리: {model_4bit.get_memory_footprint() / 1e9:.2f} GB")
'''
print(bnb_code)
205
206
# ============================================
# 6. GPTQ example (code only)
# ============================================
print("\n[6] GPTQ 양자화 (코드 예시)")
print("-" * 40)

# Illustrative-only snippet: needs `transformers` GPTQ support plus a
# calibration dataset and tokenizer in scope, so it is printed, not run.
gptq_code = '''
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

# GPTQ 설정
gptq_config = GPTQConfig(
    bits=4,
    group_size=128,
    desc_act=True,
    dataset=calibration_data,
    tokenizer=tokenizer
)

# 양자화
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=gptq_config,
    device_map="auto"
)

model.save_pretrained("./llama-2-7b-gptq-4bit")

# 사전 양자화 모델 로드
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-GPTQ",
    device_map="auto"
)
'''
print(gptq_code)
241
242
# ============================================
# 7. AWQ example (code only)
# ============================================
print("\n[7] AWQ 양자화 (코드 예시)")
print("-" * 40)

# Illustrative-only snippet: requires the `autoawq` package and model
# weights, so it is printed rather than executed.
awq_code = '''
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# 모델 로드
model = AutoAWQForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# AWQ 양자화 설정
quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM"
}

# 양자화
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized("./llama-2-7b-awq")

# AWQ 모델 추론
model = AutoAWQForCausalLM.from_quantized(
    "./llama-2-7b-awq",
    fuse_layers=True  # 레이어 퓨전으로 속도 향상
)
'''
print(awq_code)
276
277
# ============================================
# 8. QLoRA example (code only)
# ============================================
print("\n[8] QLoRA 파인튜닝 (코드 예시)")
print("-" * 40)

# Illustrative-only snippet: requires `transformers`, `peft`, `bitsandbytes`
# and a GPU, so it is printed rather than executed.
qlora_code = '''
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch

# 4비트 양자화 설정
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

# 모델 로드
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto"
)

# k-bit 학습 준비
model = prepare_model_for_kbit_training(model)

# LoRA 설정
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# LoRA 적용
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# 출력: trainable params: ~0.1%
'''
print(qlora_code)
323
324
# ============================================
# 9. Quantization memory-savings simulation
# ============================================
print("\n[9] 양자화 메모리 절약 시뮬레이션")
print("-" * 40)

def estimate_model_size(params_billions, bits):
    """Estimate weight-storage size in GB for a model.

    NOTE(review): mixes a decimal parameter count (1e9) with a binary
    divisor (1024**3), so the figure is technically GiB — confirm the
    intended unit if precision matters.
    """
    total_bytes = params_billions * 1e9 * (bits / 8)
    return total_bytes / (1024 ** 3)
336
# Tab-separated table: estimated storage per model size and precision.
model_sizes = {"7B": 7, "13B": 13, "70B": 70}
precisions = {"FP32": 32, "FP16": 16, "INT8": 8, "INT4": 4}

print("모델 크기 추정 (GB):")
print("-" * 60)
print("Model\t" + "\t".join(precisions))
print("-" * 60)

for model_name, params in model_sizes.items():
    row = "\t".join(f"{estimate_model_size(params, bits):.1f}" for bits in precisions.values())
    print(f"{model_name}\t{row}")
359
360
# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("양자화 정리")
print("=" * 60)

summary = """
양자화 핵심 개념:

1. 대칭 양자화:
   - scale = max(|x|) / (2^(bits-1) - 1)
   - x_q = round(x / scale)
   - x' = x_q * scale

2. 비대칭 양자화:
   - scale = (max - min) / (2^bits - 1)
   - zero_point = round(-min / scale)
   - x_q = round(x / scale + zero_point)

3. 양자화 방법 비교:
   - bitsandbytes: 빠른 적용, 동적 양자화
   - GPTQ: 높은 품질, 캘리브레이션 필요
   - AWQ: 빠른 양자화, 활성화 기반
   - QLoRA: 양자화 + LoRA 파인튜닝

4. 선택 가이드:
   - 프로토타이핑: bitsandbytes (load_in_8bit)
   - 메모리 제한: bitsandbytes (load_in_4bit)
   - 프로덕션: GPTQ 또는 AWQ
   - 파인튜닝: QLoRA
"""
print(summary)