LLaMA Family
LLaMA Family¶
ํ์ต ๋ชฉํ¶
- LLaMA 1/2/3์ ์ํคํ ์ฒ ์งํ ์ดํด
- RoPE, RMSNorm, SwiGLU ๋ฑ ํต์ฌ ๊ธฐ์ ์ต๋
- Grouped Query Attention (GQA) ๋ฉ์ปค๋์ฆ ํ์
- ์ค๋ฌด์์์ LLaMA ํ์ฉ๋ฒ ํ์ต
1. LLaMA ๊ฐ์¶
1.1 LLaMA์ ์์¶
LLaMA(Large Language Model Meta AI)๋ 2023๋ Meta๊ฐ ๊ณต๊ฐํ ์คํ์์ค LLM์ผ๋ก, Foundation Model ์ฐ๊ตฌ์ ๋ฏผ์ฃผํ๋ฅผ ์ด๋์์ต๋๋ค.
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ LLaMA์ ์ญ์ฌ์ ์์ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ Before LLaMA (2022): โ
โ โข ์ต๊ณ ์ฑ๋ฅ ๋ชจ๋ธ์ API๋ง ์ ๊ณต (GPT-3.5, PaLM) โ
โ โข ํ์ ์ฐ๊ตฌ์ฉ ๋ชจ๋ธ์ ์ฑ๋ฅ ๋ถ์กฑ โ
โ โข ์คํ์์ค ์ปค๋ฎค๋ํฐ์ LLM ์ ๊ทผ ์ ํ์ โ
โ โ
โ After LLaMA (2023): โ
โ โข ์ฐ๊ตฌ์๋ค์ด ์ต์ฒจ๋จ ๋ชจ๋ธ ์ง์ ์คํ ๊ฐ๋ฅ โ
โ โข Alpaca, Vicuna ๋ฑ ํ์ ๋ชจ๋ธ ํญ๋ฐ์ ์ฆ๊ฐ โ
โ โข LLM ์ฐ๊ตฌ ์๋ ๊ธ๊ฒฉํ ๊ฐ์ํ โ
โ โ
โ ํต์ฌ ๊ธฐ์ฌ: โ
โ โข Chinchilla ๊ท์น ์ ์ฉ (D=20N ์ด์) โ
โ โข ํจ์จ์ ์ํคํ
์ฒ ์ ํ ๊ฒ์ฆ โ
โ โข ํ์ต ๋ฐ์ดํฐ ๊ตฌ์ฑ ๊ณต๊ฐ โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
1.2 ๋ฒ์ ๋น๊ต¶
| ํน์ฑ | LLaMA 1 | LLaMA 2 | LLaMA 3 | LLaMA 3.1 | LLaMA 3.2 |
|---|---|---|---|---|---|
| ์ถ์ | 2023.02 | 2023.07 | 2024.04 | 2024.07 | 2024.09 |
| ํฌ๊ธฐ | 7/13/33/65B | 7/13/70B | 8/70B | 8/70/405B | 1/3/11/90B |
| ํ ํฐ | 1.4T | 2T | 15T+ | 15T+ | 15T+ |
| Context | 2K | 4K | 8K | 128K | 128K |
| License | ์ฐ๊ตฌ์ฉ | ์์ ์ (์กฐ๊ฑด๋ถ) | ์์ ์ (์ํ) | ์์ ์ (์ํ) | ์์ ์ (์ํ) |
| GQA | โ | โ (70B) | โ (์ ์ฒด) | โ (์ ์ฒด) | โ (์ ์ฒด) |
| ํน์ง | ๊ธฐ๋ณธ ์ํคํ ์ฒ | RLHF, Safety | ๊ฐ์ ๋ ์ถ๋ก | 128K ๋ค์ดํฐ๋ธ, Tool Use | ๋น์ ๋ชจ๋ธ ์ถ๊ฐ |
LLaMA 3.1/3.2 ์ฃผ์ ์ ๋ฐ์ดํธ (2024): - LLaMA 3.1: 128K ๋ค์ดํฐ๋ธ ์ปจํ ์คํธ, 405B ํ๋๊ทธ์ญ ๋ชจ๋ธ, Tool Use ๊ธฐ๋ฅ - LLaMA 3.2: ๊ฒฝ๋ ๋ชจ๋ธ(1B/3B)๊ณผ ๋น์ ๋ชจ๋ธ(11B/90B) ์ถ๊ฐ
2. LLaMA ์ํคํ ์ฒ¶
2.1 ํต์ฌ ๊ตฌ์ฑ ์์¶
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ LLaMA Architecture โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ Input Tokens โ
โ โ โ
โ โผ โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ Token Embedding โ โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ โ
โ โผ โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ RoPE (Rotary Position Embedding) โ โ ์์น ์ ๋ณด โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ โ
โ โโโโโโดโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ Transformer Block ร N โ โ
โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ โ
โ โ โ RMSNorm (Pre-normalization) โ โ โ LayerNorm ๋์ฒด โ
โ โ โ โ โ โ โ
โ โ โ Grouped Query Attention (GQA) โ โ โ KV Cache ํจ์จ โ
โ โ โ โ โ โ โ
โ โ โ Residual Connection โ โ โ
โ โ โ โ โ โ โ
โ โ โ RMSNorm โ โ โ
โ โ โ โ โ โ โ
โ โ โ SwiGLU FFN โ โ โ GELU ๋์ฒด โ
โ โ โ โ โ โ โ
โ โ โ Residual Connection โ โ โ
โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ โ
โ โผ โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ RMSNorm โ Linear โ Vocab โ โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ โ
โ โผ โ
โ Output Logits โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
2.2 ํ์ดํผํ๋ผ๋ฏธํฐ¶
"""
LLaMA ๋ชจ๋ธ ์ฌ์ ๋น๊ต
"""
LLAMA_CONFIGS = {
"llama-7b": {
"dim": 4096,
"n_layers": 32,
"n_heads": 32,
"n_kv_heads": 32, # MHA (GQA ๋ฏธ์ฌ์ฉ)
"vocab_size": 32000,
"ffn_dim": 11008, # ์ฝ 2.67 ร dim
"context_length": 2048,
},
"llama-13b": {
"dim": 5120,
"n_layers": 40,
"n_heads": 40,
"n_kv_heads": 40,
"vocab_size": 32000,
"ffn_dim": 13824,
"context_length": 2048,
},
"llama-70b": {
"dim": 8192,
"n_layers": 80,
"n_heads": 64,
"n_kv_heads": 8, # GQA! 8๊ฐ KV heads
"vocab_size": 32000,
"ffn_dim": 28672,
"context_length": 4096,
},
"llama3-8b": {
"dim": 4096,
"n_layers": 32,
"n_heads": 32,
"n_kv_heads": 8, # GQA
"vocab_size": 128256, # ํ์ฅ๋ vocab
"ffn_dim": 14336,
"context_length": 8192,
},
"llama3-70b": {
"dim": 8192,
"n_layers": 80,
"n_heads": 64,
"n_kv_heads": 8, # GQA
"vocab_size": 128256,
"ffn_dim": 28672,
"context_length": 8192,
},
# LLaMA 3.1 (2024.07)
"llama3.1-8b": {
"dim": 4096,
"n_layers": 32,
"n_heads": 32,
"n_kv_heads": 8,
"vocab_size": 128256,
"ffn_dim": 14336,
"context_length": 131072, # 128K ๋ค์ดํฐ๋ธ
},
"llama3.1-405b": {
"dim": 16384,
"n_layers": 126,
"n_heads": 128,
"n_kv_heads": 8,
"vocab_size": 128256,
"ffn_dim": 53248,
"context_length": 131072, # 128K ๋ค์ดํฐ๋ธ
},
# LLaMA 3.2 (2024.09) - ๊ฒฝ๋ ํ
์คํธ ๋ชจ๋ธ
"llama3.2-1b": {
"dim": 2048,
"n_layers": 16,
"n_heads": 32,
"n_kv_heads": 8,
"vocab_size": 128256,
"ffn_dim": 8192,
"context_length": 131072,
},
"llama3.2-3b": {
"dim": 3072,
"n_layers": 28,
"n_heads": 24,
"n_kv_heads": 8,
"vocab_size": 128256,
"ffn_dim": 8192,
"context_length": 131072,
},
}
3. RoPE (Rotary Position Embedding)¶
3.1 ๊ฐ๋ ¶
RoPE๋ ์์น ์ ๋ณด๋ฅผ ํ์ ํ๋ ฌ๋ก ์ธ์ฝ๋ฉํ๋ ๋ฐฉ์์ ๋๋ค.
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Position Encoding ๋น๊ต โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ 1. Sinusoidal (Transformer ์๋ณธ) โ
โ PE(pos, 2i) = sin(pos / 10000^(2i/d)) โ
โ PE(pos, 2i+1) = cos(pos / 10000^(2i/d)) โ
โ โ ์
๋ ฅ์ ๋ํจ (additive) โ
โ โ ์๋ ์์น ์ ๋ณด ์ฝํจ โ
โ โ
โ 2. Learned (BERT, GPT) โ
โ PE = Embedding(position) โ
โ โ ํ์ต๋ ๋ฒกํฐ โ
โ โ ํ์ต ์ค ๋ณธ ๊ธธ์ด ์ด์ ์ผ๋ฐํ ์ด๋ ค์ โ
โ โ
โ 3. RoPE (LLaMA) โ
โ R(ฮธ) = ํ์ ํ๋ ฌ, ฮธ = f(position) โ
โ q' = R(ฮธ_m) ร q, k' = R(ฮธ_n) ร k โ
โ q' ยท k' = q ยท k ร cos(ฮธ_m - ฮธ_n) โ
โ โ ์๋ ์์น ์์ฐ์ค๋ฝ๊ฒ ์ธ์ฝ๋ฉ โ
โ โ ๊ธธ์ด ์ธ์ฝ ๊ฐ๋ฅ (with modifications) โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
3.2 ์ํ์ ์ดํด¶
import torch
import math
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
"""
RoPE๋ฅผ ์ํ ๋ณต์์ ์ฃผํ์ ์ฌ์ ๊ณ์ฐ
Args:
dim: ์๋ฒ ๋ฉ ์ฐจ์ (head_dim)
seq_len: ์ต๋ ์ํ์ค ๊ธธ์ด
theta: ๊ธฐ๋ณธ ์ฃผํ์ (10000)
Returns:
freqs_cis: (seq_len, dim//2) ๋ณต์์ ํ
์
"""
# ์ฃผํ์ ๊ณ์ฐ: ฮธ_i = 1 / (theta^(2i/d))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
# ์์น๋ณ ๊ฐ๋: m * ฮธ_i
t = torch.arange(seq_len)
freqs = torch.outer(t, freqs) # (seq_len, dim//2)
# ๋ณต์์ ํํ: e^(i*ฮธ) = cos(ฮธ) + i*sin(ฮธ)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(xq, xk, freqs_cis):
"""
Query์ Key์ RoPE ์ ์ฉ
Args:
xq: Query (batch, seq_len, n_heads, head_dim)
xk: Key (batch, seq_len, n_kv_heads, head_dim)
freqs_cis: ์ฌ์ ๊ณ์ฐ๋ ๋ณต์์ ์ฃผํ์
Returns:
ํ์ ๋ Query์ Key
"""
# ์ค์๋ฅผ ๋ณต์์๋ก ๋ณํ (์ธ์ ํ 2๊ฐ์ฉ ๋ฌถ์)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# ํ์ ์ ์ฉ (๋ณต์์ ๊ณฑ)
freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2) # (1, seq, 1, dim//2)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
# ์์
batch, seq_len, n_heads, head_dim = 2, 128, 32, 128
xq = torch.randn(batch, seq_len, n_heads, head_dim)
xk = torch.randn(batch, seq_len, n_heads, head_dim)
freqs_cis = precompute_freqs_cis(head_dim, seq_len)
xq_rope, xk_rope = apply_rotary_emb(xq, xk, freqs_cis)
print(f"Output shape: {xq_rope.shape}") # (2, 128, 32, 128)
3.3 RoPE์ ์ฅ์ ¶
"""
RoPE์ ์ฅ์ :
1. ์๋ ์์น ์์ฐ์ค๋ฝ๊ฒ ์ธ์ฝ๋ฉ
- q_m ยท k_n โ cos(ฮธ_m - ฮธ_n)
- ์ ๋ ์์น๊ฐ ์๋ ์๋ ๊ฑฐ๋ฆฌ ์์กด
2. ์ธ์ฝ ๊ฐ๋ฅ์ฑ
- ํ์ต ์ ๋ณธ ๊ธธ์ด ์ด์์ผ๋ก ํ์ฅ ๊ฐ๋ฅ
- (๋จ, ์ฑ๋ฅ ์ ํ ์์ โ NTK, YaRN์ผ๋ก ๊ฐ์ )
3. ํจ์จ์ฑ
- ์ถ๊ฐ ํ๋ผ๋ฏธํฐ ์์
- Element-wise ์ฐ์ฐ์ผ๋ก ๋น ๋ฆ
4. ์ ํ Self-attention๊ณผ ํธํ
- ์ผ๋ถ ํจ์จ์ attention ๋ฐฉ์๊ณผ ๊ฒฐํฉ ๊ฐ๋ฅ
"""
4. RMSNorm¶
4.1 ๊ฐ๋ ¶
RMSNorm์ LayerNorm์ ๋จ์ํ ๋ฒ์ ์ผ๋ก, ํ๊ท ๊ณ์ฐ์ ์ ๊ฑฐํฉ๋๋ค.
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ LayerNorm vs RMSNorm โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ LayerNorm: โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ ฮผ = mean(x) โ
โ ฯ = std(x) โ
โ y = ฮณ ร (x - ฮผ) / ฯ + ฮฒ โ
โ โ
โ โข ํ๊ท ๋นผ๊ธฐ + ๋ถ์ฐ์ผ๋ก ๋๋๊ธฐ โ
โ โข ํ์ต ๊ฐ๋ฅํ scale(ฮณ)์ shift(ฮฒ) โ
โ โ
โ RMSNorm: โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ RMS(x) = sqrt(mean(x^2)) โ
โ y = ฮณ ร x / RMS(x) โ
โ โ
โ โข ํ๊ท ๋นผ๊ธฐ ์์ โ Re-centering ์ ๊ฑฐ โ
โ โข RMS๋ก๋ง ์ค์ผ์ผ๋ง โ
โ โข shift(ฮฒ) ์์ โ
โ โข ์ฐ์ฐ๋ ๊ฐ์, ์ฑ๋ฅ ์ ์ฌ โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
4.2 ๊ตฌํ¶
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization
๋
ผ๋ฌธ: https://arxiv.org/abs/1910.07467
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim)) # scale parameter ฮณ
def _norm(self, x):
# RMS = sqrt(mean(x^2))
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
# output = ฮณ ร (x / RMS(x))
output = self._norm(x.float()).type_as(x)
return output * self.weight
# LayerNorm๊ณผ ๋น๊ต
x = torch.randn(2, 10, 512)
layer_norm = nn.LayerNorm(512)
rms_norm = RMSNorm(512)
# ์ฐ์ฐ ์๊ฐ ๋น๊ต (RMSNorm์ด ์ฝ๊ฐ ๋น ๋ฆ)
import time
start = time.time()
for _ in range(1000):
_ = layer_norm(x)
print(f"LayerNorm: {time.time() - start:.4f}s")
start = time.time()
for _ in range(1000):
_ = rms_norm(x)
print(f"RMSNorm: {time.time() - start:.4f}s")
5. SwiGLU¶
5.1 ๊ฐ๋ ¶
SwiGLU๋ GLU(Gated Linear Unit)์ ๋ณํ์ผ๋ก, Swish ํ์ฑํ ํจ์๋ฅผ ์ฌ์ฉํฉ๋๋ค.
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ FFN ํ์ฑํ ํจ์ ๋น๊ต โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ 1. ReLU FFN (Transformer ์๋ณธ): โ
โ FFN(x) = max(0, xWโ + bโ)Wโ + bโ โ
โ โข ๋จ์ํ์ง๋ง ์์ ์์ญ ์ ๋ณด ์์ค โ
โ โ
โ 2. GELU FFN (BERT, GPT): โ
โ FFN(x) = GELU(xWโ)Wโ โ
โ GELU(x) = x ร ฮฆ(x) (ฮฆ = CDF of normal) โ
โ โข ๋ถ๋๋ฌ์ด ํ์ฑํ, ์ฑ๋ฅ ํฅ์ โ
โ โ
โ 3. SwiGLU FFN (LLaMA): โ
โ FFN(x) = (Swish(xWโ) โ xV)Wโ โ
โ Swish(x) = x ร ฯ(x) (ฯ = sigmoid) โ
โ โ = element-wise multiplication โ
โ โ
โ โข Gating mechanism์ผ๋ก ์ ๋ณด ํ๋ฆ ์ ์ด โ
โ โข ๋ ๋ง์ ํ๋ผ๋ฏธํฐ, ๋ ์ข์ ์ฑ๋ฅ โ
โ โข 2/3 ร 4d hidden dim (ํ๋ผ๋ฏธํฐ ์ ์ ์ง) โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
5.2 ๊ตฌํ¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class SwiGLU(nn.Module):
"""
SwiGLU: Swish-Gated Linear Unit
FFN(x) = (Swish(xWโ) โ xV) Wโ
๋
ผ๋ฌธ: https://arxiv.org/abs/2002.05202
"""
def __init__(self, dim: int, hidden_dim: int = None, multiple_of: int = 256):
super().__init__()
# hidden_dim ๊ณ์ฐ: 2/3 ร 4d, 256์ ๋ฐฐ์๋ก ๋ฐ์ฌ๋ฆผ
if hidden_dim is None:
hidden_dim = int(2 * (4 * dim) / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False) # gate
self.w2 = nn.Linear(hidden_dim, dim, bias=False) # down projection
self.w3 = nn.Linear(dim, hidden_dim, bias=False) # up projection
def forward(self, x):
# SwiGLU: Swish(xWโ) โ (xWโ) โ Wโ
return self.w2(F.silu(self.w1(x)) * self.w3(x))
# ๊ธฐ์กด FFN๊ณผ ๋น๊ต
class StandardFFN(nn.Module):
def __init__(self, dim: int, hidden_dim: int = None):
super().__init__()
if hidden_dim is None:
hidden_dim = 4 * dim
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
return self.w2(F.gelu(self.w1(x)))
# ํ๋ผ๋ฏธํฐ ์ ๋น๊ต
dim = 4096
swiglu = SwiGLU(dim) # 3๊ฐ์ linear: dimโhidden, dimโhidden, hiddenโdim
standard = StandardFFN(dim) # 2๊ฐ์ linear: dimโ4*dim, 4*dimโdim
print(f"SwiGLU params: {sum(p.numel() for p in swiglu.parameters()):,}")
print(f"Standard FFN params: {sum(p.numel() for p in standard.parameters()):,}")
# SwiGLU๊ฐ ์ฝ๊ฐ ๋ ๋ง์ง๋ง hidden_dim ์กฐ์ ์ผ๋ก ๋น์ทํ๊ฒ ๋ง์ถค
6. Grouped Query Attention (GQA)¶
6.1 ๊ฐ๋ ¶
GQA๋ Multi-Head Attention๊ณผ Multi-Query Attention์ ์ค๊ฐ ํํ์ ๋๋ค.
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Attention ๋ฐฉ์ ๋น๊ต โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ 1. Multi-Head Attention (MHA): โ
โ Q heads: 32 โ K heads: 32 โ V heads: 32 โ
โ โข ๊ฐ Q head๊ฐ ๋
๋ฆฝ์ ์ธ KV head ์ฌ์ฉ โ
โ โข ๋ฉ๋ชจ๋ฆฌ: ๋ง์ (32 ร KV cache) โ
โ โ
โ 2. Multi-Query Attention (MQA): โ
โ Q heads: 32 โ K heads: 1 โ V heads: 1 โ
โ โข ๋ชจ๋ Q head๊ฐ ๊ฐ์ KV ๊ณต์ โ
โ โข ๋ฉ๋ชจ๋ฆฌ: ์ต์ (1 ร KV cache) โ
โ โข ํ์ง: MHA๋ณด๋ค ์ฝ๊ฐ ๋ฎ์ โ
โ โ
โ 3. Grouped Query Attention (GQA): โ
โ Q heads: 32 โ K heads: 8 โ V heads: 8 โ
โ โข Q heads๋ฅผ ๊ทธ๋ฃน์ผ๋ก ๋๋ KV ๊ณต์ โ
โ โข ์: 4๊ฐ์ Q head๊ฐ 1๊ฐ์ KV head ๊ณต์ โ
โ โข ๋ฉ๋ชจ๋ฆฌ: ์ค๊ฐ (8 ร KV cache) โ
โ โข ํ์ง: MHA์ ๊ฑฐ์ ๋์ผ โ
โ โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ MHA โ MQA โ GQA โ โ
โ โ Q Q Q Q Q Q โ Q Q Q Q Q Q โ Q QโQ QโQ Q โ โ
โ โ โ โ โ โ โ โ โ โ โ โ โ โ โ โ โ โโโ โโโ โ โ โ
โ โ K K K K K K โ K โ K โ K โ K โ โ
โ โ V V V V V V โ V โ V โ V โ V โ โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
6.2 ๊ตฌํ¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class GroupedQueryAttention(nn.Module):
"""
Grouped Query Attention (GQA)
๋
ผ๋ฌธ: https://arxiv.org/abs/2305.13245
"""
def __init__(
self,
dim: int,
n_heads: int = 32,
n_kv_heads: int = 8, # KV heads ์ (< n_heads)
head_dim: int = None,
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim or dim // n_heads
# Q heads > KV heads ๊ฒ์ฆ
assert n_heads % n_kv_heads == 0
self.n_rep = n_heads // n_kv_heads # ๊ฐ KV head๊ฐ ๋ด๋นํ๋ Q head ์
self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
def forward(self, x, freqs_cis=None, mask=None, kv_cache=None):
batch, seq_len, _ = x.shape
# Q, K, V ๊ณ์ฐ
xq = self.wq(x).view(batch, seq_len, self.n_heads, self.head_dim)
xk = self.wk(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
xv = self.wv(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
# RoPE ์ ์ฉ (์๋ ๊ฒฝ์ฐ)
if freqs_cis is not None:
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
# KV Cache ์ฒ๋ฆฌ (์ถ๋ก ์)
if kv_cache is not None:
cache_k, cache_v = kv_cache
xk = torch.cat([cache_k, xk], dim=1)
xv = torch.cat([cache_v, xv], dim=1)
# KV heads ํ์ฅ: n_kv_heads โ n_heads
# (batch, seq, n_kv_heads, head_dim) โ (batch, seq, n_heads, head_dim)
xk = self._repeat_kv(xk)
xv = self._repeat_kv(xv)
# Attention ๊ณ์ฐ
xq = xq.transpose(1, 2) # (batch, n_heads, seq, head_dim)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask
attn = F.softmax(scores, dim=-1)
output = torch.matmul(attn, xv)
# ๊ฒฐ๊ณผ ํฉ์น๊ธฐ
output = output.transpose(1, 2).contiguous().view(batch, seq_len, -1)
return self.wo(output), (xk, xv)
def _repeat_kv(self, x):
"""KV heads๋ฅผ Q heads ์๋งํผ ๋ฐ๋ณต"""
if self.n_rep == 1:
return x
batch, seq_len, n_kv_heads, head_dim = x.shape
x = x[:, :, :, None, :].expand(batch, seq_len, n_kv_heads, self.n_rep, head_dim)
return x.reshape(batch, seq_len, n_kv_heads * self.n_rep, head_dim)
# ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ๋น๊ต
def compare_kv_cache_memory(seq_len, batch_size=1, dtype_bytes=2):
"""KV cache ๋ฉ๋ชจ๋ฆฌ ๋น๊ต (FP16 ๊ธฐ์ค)"""
configs = {
"LLaMA-70B (MHA)": {"n_layers": 80, "n_kv_heads": 64, "head_dim": 128},
"LLaMA-70B (GQA)": {"n_layers": 80, "n_kv_heads": 8, "head_dim": 128},
}
for name, cfg in configs.items():
kv_mem = (2 * # K and V
batch_size *
cfg["n_layers"] *
seq_len *
cfg["n_kv_heads"] *
cfg["head_dim"] *
dtype_bytes)
print(f"{name}: {kv_mem / 1e9:.2f} GB for {seq_len} tokens")
compare_kv_cache_memory(4096)
# LLaMA-70B (MHA): 5.24 GB for 4096 tokens
# LLaMA-70B (GQA): 0.66 GB for 4096 tokens โ 8๋ฐฐ ์ ์ฝ!
7. LLaMA ์ค์ต¶
7.1 HuggingFace๋ก ์ฌ์ฉํ๊ธฐ¶
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# LLaMA 2 7B ๋ก๋
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto"
)
# ํ
์คํธ ์์ฑ
def generate_text(prompt, max_new_tokens=100, temperature=0.7):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=True,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
# ์ฌ์ฉ ์์
prompt = "Explain the concept of machine learning in simple terms:"
response = generate_text(prompt)
print(response)
7.2 ์์ํ๋ก ํจ์จ์ ์ฌ์ฉ¶
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
# 4-bit ์์ํ ์ค์
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # NormalFloat4
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True, # ์ด์ค ์์ํ
)
# ์์ํ๋ ๋ชจ๋ธ ๋ก๋
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
quantization_config=bnb_config,
device_map="auto"
)
# ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ํ์ธ
print(f"Model memory: {model.get_memory_footprint() / 1e9:.2f} GB")
# ์ฝ 4GB (FP16 ๋๋น 75% ์ ์ฝ)
7.3 LLaMA 3 ์ฌ์ฉ¶
# LLaMA 3 8B (128K ํ ํฌ๋์ด์ )
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16, # LLaMA 3๋ bfloat16 ๊ถ์ฅ
device_map="auto"
)
# LLaMA 3 ํน์ง:
# - 128K ํ ํฌ๋์ด์ (๋ ํจ์จ์ )
# - 8K ๊ธฐ๋ณธ ์ปจํ
์คํธ (128K๊น์ง ํ์ฅ ๊ฐ๋ฅ)
# - ๊ฐ์ ๋ ์ถ๋ก ๋ฅ๋ ฅ
prompt = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
8. LLaMA 3.1/3.2 ์์ธ¶
8.1 LLaMA 3.1 (2024๋ 7์)¶
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ LLaMA 3.1 ์ฃผ์ ํน์ง โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ 1. 128K ๋ค์ดํฐ๋ธ ์ปจํ
์คํธ โ
โ โข ํ์ต ์๋ถํฐ 128K ํ ํฐ ์ง์ โ
โ โข RoPE scaling ์์ด ๊ธด ๋ฌธ๋งฅ ์ฒ๋ฆฌ โ
โ โ
โ 2. 405B ํ๋๊ทธ์ญ ๋ชจ๋ธ โ
โ โข GPT-4 ์์ค ์ฑ๋ฅ โ
โ โข 126๊ฐ ๋ ์ด์ด, 16K ์๋ฒ ๋ฉ ์ฐจ์ โ
โ โ
โ 3. Tool Use ๊ธฐ๋ฅ โ
โ โข ํจ์ ํธ์ถ (Function Calling) โ
โ โข ์ฝ๋ ์ธํฐํ๋ฆฌํฐ โ
โ โข ๊ฒ์ ๋๊ตฌ ํตํฉ โ
โ โ
โ 4. ๋ค๊ตญ์ด ์ง์ ๊ฐํ โ
โ โข ์์ด, ๋
์ผ์ด, ํ๋์ค์ด, ์ดํ๋ฆฌ์์ด โ
โ โข ํฌ๋ฅดํฌ๊ฐ์ด, ํ๋์ด, ์คํ์ธ์ด, ํ๊ตญ์ด โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# LLaMA 3.1 Tool Use ์์
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto"
)
# Tool Use ํ์ (LLaMA 3.1 ํนํ)
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather for a location",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string", "description": "City name"}
},
"required": ["location"]
}
}
}
]
messages = [
{"role": "system", "content": "You are a helpful assistant with access to tools."},
{"role": "user", "content": "What's the weather in Seoul?"}
]
# Tool ํธ์ถ ์์ฑ
inputs = tokenizer.apply_chat_template(
messages,
tools=tools,
return_tensors="pt"
).to(model.device)
outputs = model.generate(inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0]))
8.2 LLaMA 3.2 (2024๋ 9์)¶
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ LLaMA 3.2 ๋ชจ๋ธ ๋ผ์ธ์
โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ ๊ฒฝ๋ ํ
์คํธ ๋ชจ๋ธ (on-device ์ต์ ํ): โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ LLaMA 3.2 1B: ๋ชจ๋ฐ์ผ/์์ง ๋๋ฐ์ด์ค์ฉ โ โ
โ โ LLaMA 3.2 3B: ๊ฒฝ๋ ์ ํ๋ฆฌ์ผ์ด์
์ฉ โ โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ
โ ๋น์ ๋ชจ๋ธ (๋ฉํฐ๋ชจ๋ฌ): โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ LLaMA 3.2 11B-Vision: ์ด๋ฏธ์ง ์ดํด โ โ
โ โ LLaMA 3.2 90B-Vision: ๊ณ ์ฑ๋ฅ ๋น์ ํ์คํฌ โ โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ
โ ํน์ง: โ
โ โข 1B/3B: 128K ์ปจํ
์คํธ, on-device ์ถ๋ก ๊ฐ๋ฅ โ
โ โข 11B/90B: ๋น์ ์ธ์ฝ๋ ํตํฉ, ์ด๋ฏธ์ง+ํ
์คํธ ์ฒ๋ฆฌ โ
โ โข Qualcomm, MediaTek ํ๋์จ์ด ์ต์ ํ โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# LLaMA 3.2 Vision ์์
from transformers import MllamaForConditionalGeneration, AutoProcessor
from PIL import Image
import requests
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)
# ์ด๋ฏธ์ง ๋ก๋
url = "https://example.com/image.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# ๋น์ ๋ํ
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "์ด ์ด๋ฏธ์ง์์ ๋ฌด์์ด ๋ณด์ด๋์?"}
]
}
]
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(image, input_text, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=256)
print(processor.decode(output[0]))
์ ๋ฆฌ¶
LLaMA ํต์ฌ ๊ธฐ์ ¶
| ๊ธฐ์ | ํจ๊ณผ |
|---|---|
| RoPE | ์๋ ์์น ์ธ์ฝ๋ฉ, ๊ธธ์ด ํ์ฅ ๊ฐ๋ฅ |
| RMSNorm | LayerNorm๋ณด๋ค ๋น ๋ฅด๊ณ ํจ๊ณผ์ |
| SwiGLU | GELU๋ณด๋ค ์ข์ ์ฑ๋ฅ |
| GQA | KV cache ๋ฉ๋ชจ๋ฆฌ 8๋ฐฐ ์ ์ฝ |
์ค๋ฌด ๊ถ์ฅ ์ฌํญ¶
- 7B/8B: ๋จ์ผ GPU (16GB+), ๋น ๋ฅธ ์คํ์ฉ
- 13B: 24GB GPU, ๊ท ํ ์กํ ์ ํ
- 70B: ์ฌ๋ฌ GPU, ์ต๊ณ ์ฑ๋ฅ ํ์ ์
- ์์ํ: 4-bit์ผ๋ก ๋ฉ๋ชจ๋ฆฌ 75% ์ ์ฝ
๋ค์ ๋จ๊ณ¶
- 09_Mistral_MoE.md: Mixture of Experts ์ํคํ ์ฒ
- 19_PEFT_Unified.md: LLaMA Fine-tuning (LoRA)
์ฐธ๊ณ ์๋ฃ¶
ํต์ฌ ๋ ผ๋ฌธ¶
- Touvron et al. (2023). "LLaMA: Open and Efficient Foundation Language Models"
- Touvron et al. (2023). "LLaMA 2: Open Foundation and Fine-Tuned Chat Models"
- Su et al. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding"
- Ainslie et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models"