Mistral & Mixture of Experts
Mistral & Mixture of Experts¶
ํ์ต ๋ชฉํ¶
- Mistral 7B์ ์ํคํ ์ฒ ํน์ง ์ดํด
- Mixture of Experts (MoE) ๊ฐ๋ ๊ณผ ๋์ ์๋ฆฌ ํ์
- Mixtral 8x7B ๊ตฌ์กฐ ํ์ต
- Sparse MoE์ ์ฅ๋จ์ ๊ณผ ์ค๋ฌด ํ์ฉ๋ฒ ์ต๋
1. Mistral 7B ๊ฐ์¶
1.1 Mistral์ ํ์ ¶
Mistral 7B๋ 2023๋ Mistral AI๊ฐ ๊ณต๊ฐํ ๋ชจ๋ธ๋ก, 7B ํ๋ผ๋ฏธํฐ๋ก 13B ๊ธ ์ฑ๋ฅ์ ๋ฌ์ฑํ์ต๋๋ค.
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Mistral 7B ํน์ง โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ ์ฑ๋ฅ ๋น๊ต (2023.10 ๊ธฐ์ค): โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ Model โ Params โ MMLU โ HellaSwag โ GSM8K โ โ
โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ LLaMA 2 7B โ 7B โ 45.3 โ 77.2 โ 14.6 โ โ
โ โ LLaMA 2 13B โ 13B โ 54.8 โ 80.7 โ 28.7 โ โ
โ โ Mistral 7B โ 7B โ 60.1 โ 81.3 โ 52.2 โ โ! โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ
โ ํต์ฌ ๊ธฐ์ : โ
โ โข Sliding Window Attention (SWA) โ
โ โข Grouped Query Attention (GQA) โ
โ โข ๋ ๋ง์ ๋ฐ์ดํฐ๋ก Over-training โ
โ โข Flash Attention 2 ์ต์ ํ โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
1.2 Mistral ์ํคํ ์ฒ ์ฌ์¶
MISTRAL_CONFIGS = {
"mistral-7b": {
"dim": 4096,
"n_layers": 32,
"n_heads": 32,
"n_kv_heads": 8, # GQA
"head_dim": 128,
"hidden_dim": 14336,
"vocab_size": 32000,
"context_length": 32768, # ๊ธฐ์ ์ ํ๊ณ
"sliding_window": 4096, # Sliding Window Attention
"rope_theta": 10000.0,
},
}
# LLaMA 2 7B์ ๋น๊ต
LLAMA2_7B = {
"dim": 4096,
"n_layers": 32,
"n_heads": 32,
"n_kv_heads": 32, # MHA (GQA ๋ฏธ์ฌ์ฉ)
"hidden_dim": 11008,
"context_length": 4096,
"sliding_window": None, # ์ ์ฒด attention
}
2. Sliding Window Attention (SWA)¶
2.1 ๊ฐ๋ ¶
Sliding Window Attention์ ๊ฐ ํ ํฐ์ด ๊ณ ์ ๋ ์๋์ฐ ๋ด์ ํ ํฐ๋ง attendํ๋๋ก ์ ํํฉ๋๋ค.
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Sliding Window Attention โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ Full Attention (๊ธฐ์กด): โ
โ โโโโโโโโโโโโโโโโโโโโโโโโ โ
โ ๋ชจ๋ ํ ํฐ์ด ๋ชจ๋ ์ด์ ํ ํฐ์ attend โ
โ ๋ณต์ก๋: O(nยฒ) โ
โ โ
โ Position: 1 2 3 4 5 6 7 8 9 10 โ
โ Token 10: โ โ โ โ โ โ โ โ โ โ โ
โ โ
โ Sliding Window (W=4): โ
โ โโโโโโโโโโโโโโโโโโโโโโโโ โ
โ ์๋์ฐ ํฌ๊ธฐ W ๋ด์ ํ ํฐ๋ง attend โ
โ ๋ณต์ก๋: O(n ร W) โ
โ โ
โ Position: 1 2 3 4 5 6 7 8 9 10 โ
โ Token 10: โ โ โ โ โ โ โ โ โ โ โ
โ โ โโโโโโโโโฌโโโโโโโโ โ
โ Window start Window (W=4) โ
โ โ
โ ๋ ์ด์ด ์๊ธฐ ํจ๊ณผ: โ
โ โโโโโโโโโโโโโโโโโโโโโโโโ โ
โ L๊ฐ ๋ ์ด์ด โ ์ค์ receptive field = L ร W โ
โ 32 layers ร 4096 window = 131,072 ํ ํฐ ๋ฒ์! โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
2.2 ๊ตฌํ¶
import torch
import torch.nn.functional as F
import math
def sliding_window_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
window_size: int = 4096,
causal: bool = True,
):
"""
Sliding Window Attention ๊ตฌํ
Args:
query: (batch, n_heads, seq_len, head_dim)
key: (batch, n_heads, seq_len, head_dim)
value: (batch, n_heads, seq_len, head_dim)
window_size: ์๋์ฐ ํฌ๊ธฐ
causal: Causal masking ์ ์ฉ ์ฌ๋ถ
"""
batch, n_heads, seq_len, head_dim = query.shape
scale = 1.0 / math.sqrt(head_dim)
# Attention scores
scores = torch.matmul(query, key.transpose(-2, -1)) * scale
# Sliding window mask ์์ฑ
# ๊ฐ ์์น i๋ max(0, i-W+1)๋ถํฐ i๊น์ง๋ง attend
row_idx = torch.arange(seq_len).unsqueeze(1) # (seq, 1)
col_idx = torch.arange(seq_len).unsqueeze(0) # (1, seq)
# Causal: col <= row
# Window: col >= row - window_size + 1
if causal:
mask = (col_idx <= row_idx) & (col_idx >= row_idx - window_size + 1)
else:
mask = torch.abs(row_idx - col_idx) < window_size
# Mask ์ ์ฉ
mask = mask.to(scores.device)
scores = scores.masked_fill(~mask, float('-inf'))
# Softmax & output
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, value)
return output
# ๋ฉ๋ชจ๋ฆฌ ๋น๊ต
def compare_attention_memory(seq_len, window_size=4096):
"""Full vs Sliding Window ๋ฉ๋ชจ๋ฆฌ ๋น๊ต"""
full_attention_mem = seq_len * seq_len # O(nยฒ)
sliding_window_mem = seq_len * window_size # O(n ร W)
print(f"Sequence length: {seq_len:,}")
print(f"Full Attention: {full_attention_mem:,} elements")
print(f"Sliding Window: {sliding_window_mem:,} elements")
print(f"Memory savings: {(1 - sliding_window_mem/full_attention_mem)*100:.1f}%")
compare_attention_memory(32768, 4096)
# Sequence length: 32,768
# Full Attention: 1,073,741,824 elements
# Sliding Window: 134,217,728 elements
# Memory savings: 87.5%
2.3 Rolling Buffer KV Cache¶
"""
Rolling Buffer: ๊ณ ์ ํฌ๊ธฐ KV cache๋ก ๊ธด ์ํ์ค ์ฒ๋ฆฌ
์ผ๋ฐ KV Cache:
- ๋ชจ๋ ํ ํฐ์ KV ์ ์ฅ
- ๋ฉ๋ชจ๋ฆฌ: O(seq_len)
Rolling Buffer:
- window_size๋งํผ๋ง ์ ์ฅ
- ์ค๋๋ KV๋ ๋ฎ์ด์
- ๋ฉ๋ชจ๋ฆฌ: O(window_size) = ์์!
์์ (window=4):
Step 1: [K1, K2, K3, K4]
Step 2: [K5, K2, K3, K4] โ K1 ์์น์ K5 ์ ์ฅ
Step 3: [K5, K6, K3, K4] โ K2 ์์น์ K6 ์ ์ฅ
...
์ฅ์ :
- ๋ฌดํ ์ํ์ค ์ฒ๋ฆฌ ๊ฐ๋ฅ (๋ฉ๋ชจ๋ฆฌ ๊ณ ์ )
- ์ถ๋ก ์๋ ์ผ์
๋จ์ :
- ์ค๋๋ ์ ๋ณด ์์ค
- ๋ ์ด์ด ์๊ธฐ๋ก ๋ณด์
"""
class RollingKVCache:
def __init__(self, window_size: int, n_layers: int, n_kv_heads: int, head_dim: int):
self.window_size = window_size
self.cache_k = torch.zeros(n_layers, 1, window_size, n_kv_heads, head_dim)
self.cache_v = torch.zeros(n_layers, 1, window_size, n_kv_heads, head_dim)
self.pos = 0
def update(self, layer_idx: int, k: torch.Tensor, v: torch.Tensor):
"""์๋ก์ด KV๋ฅผ cache์ ์ถ๊ฐ (circular buffer)"""
seq_len = k.shape[1]
for i in range(seq_len):
idx = (self.pos + i) % self.window_size
self.cache_k[layer_idx, :, idx] = k[:, i]
self.cache_v[layer_idx, :, idx] = v[:, i]
self.pos = (self.pos + seq_len) % self.window_size
def get(self, layer_idx: int):
return self.cache_k[layer_idx], self.cache_v[layer_idx]
3. Mixture of Experts (MoE) ๊ธฐ์ด¶
3.1 MoE ๊ฐ๋ ¶
Mixture of Experts๋ ์ฌ๋ฌ "์ ๋ฌธ๊ฐ" ๋คํธ์ํฌ ์ค ์ผ๋ถ๋ง ํ์ฑํํ์ฌ ํจ์จ์ฑ์ ๋์ด๋ ์ํคํ ์ฒ์ ๋๋ค.
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Mixture of Experts ๊ฐ๋
โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ Dense Model: โ
โ โโโโโโโโโโโโโโโโโ โ
โ Input โโโบ [FFN (์ ์ฒด)] โโโบ Output โ
โ โข ๋ชจ๋ ํ๋ผ๋ฏธํฐ๊ฐ ๋งค๋ฒ ํ์ฑํ โ
โ โข ๊ณ์ฐ๋ = ํ๋ผ๋ฏธํฐ ์์ ๋น๋ก โ
โ โ
โ Sparse MoE Model: โ
โ โโโโโโโโโโโโโโโโโ โ
โ โโโโบ Expert 1 โโโ โ
โ โ โ โ
โ Input โโโบ Router โโโโโโผโโโบ Expert 2 โโโผโโโบ Combine โโโบ Output โ
โ โ โ โ โ
โ (Top-K ์ ํ) โโโโบ Expert 3 โโโ โ
โ โโโโบ Expert N (๋นํ์ฑํ) โ
โ โ
โ โข ๋ผ์ฐํฐ๊ฐ K๊ฐ ์ ๋ฌธ๊ฐ๋ง ์ ํ โ
โ โข ํ๋ผ๋ฏธํฐ ๅค, ๊ณ์ฐ๋ ๅฐ โ
โ โข ์: 8๊ฐ ์ ๋ฌธ๊ฐ, 2๊ฐ๋ง ํ์ฑํ โ ๊ณ์ฐ๋ 1/4 โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
3.2 Router (Gating Network)¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class TopKRouter(nn.Module):
"""
Top-K Router: ์
๋ ฅ๋ง๋ค K๊ฐ์ ์ ๋ฌธ๊ฐ ์ ํ
์์:
G(x) = softmax(TopK(x ยท W_g))
์ฌ๊ธฐ์ TopK๋ ์์ K๊ฐ๋ง ์ ์ง, ๋๋จธ์ง๋ -inf
"""
def __init__(self, dim: int, num_experts: int, top_k: int = 2):
super().__init__()
self.top_k = top_k
self.num_experts = num_experts
self.gate = nn.Linear(dim, num_experts, bias=False)
def forward(self, x):
"""
Args:
x: (batch, seq_len, dim)
Returns:
router_probs: (batch, seq_len, top_k) - ์ ํ๋ ์ ๋ฌธ๊ฐ ๊ฐ์ค์น
expert_indices: (batch, seq_len, top_k) - ์ ํ๋ ์ ๋ฌธ๊ฐ ์ธ๋ฑ์ค
"""
# ๋ผ์ฐํฐ ๋ก์ง ๊ณ์ฐ
logits = self.gate(x) # (batch, seq_len, num_experts)
# Top-K ์ ํ
top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
# Softmax (์ ํ๋ ์ ๋ฌธ๊ฐ๋ค ์ฌ์ด์์)
router_probs = F.softmax(top_k_logits, dim=-1)
return router_probs, top_k_indices
# ์์
router = TopKRouter(dim=4096, num_experts=8, top_k=2)
x = torch.randn(2, 10, 4096) # batch=2, seq=10
probs, indices = router(x)
print(f"Router probs shape: {probs.shape}") # (2, 10, 2)
print(f"Expert indices shape: {indices.shape}") # (2, 10, 2)
print(f"Selected experts for token 0: {indices[0, 0]}") # e.g., tensor([3, 7])
3.3 Expert Layer¶
class MoELayer(nn.Module):
"""
Mixture of Experts Layer
๊ฐ ํ ํฐ์ด Top-K ์ ๋ฌธ๊ฐ์๊ฒ ๋ผ์ฐํ
๋์ด ์ฒ๋ฆฌ๋จ
"""
def __init__(
self,
dim: int,
hidden_dim: int,
num_experts: int = 8,
top_k: int = 2,
):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
# Router
self.router = TopKRouter(dim, num_experts, top_k)
# Experts (๊ฐ๊ฐ ๋
๋ฆฝ์ ์ธ FFN)
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(dim, hidden_dim, bias=False),
nn.SiLU(),
nn.Linear(hidden_dim, dim, bias=False)
)
for _ in range(num_experts)
])
def forward(self, x):
"""
Args:
x: (batch, seq_len, dim)
Returns:
output: (batch, seq_len, dim)
"""
batch, seq_len, dim = x.shape
# ๋ผ์ฐํ
router_probs, expert_indices = self.router(x)
# router_probs: (batch, seq_len, top_k)
# expert_indices: (batch, seq_len, top_k)
# ์ถ๋ ฅ ์ด๊ธฐํ
output = torch.zeros_like(x)
# ๊ฐ ์ ๋ฌธ๊ฐ๋ณ๋ก ์ฒ๋ฆฌ (๊ฐ๋จํ ๊ตฌํ, ์ค์ ๋ก๋ ๋ ์ต์ ํ๋จ)
for k in range(self.top_k):
expert_idx = expert_indices[:, :, k] # (batch, seq_len)
expert_prob = router_probs[:, :, k:k+1] # (batch, seq_len, 1)
for e in range(self.num_experts):
# ์ด ์ ๋ฌธ๊ฐ๊ฐ ์ ํ๋ ์์น ์ฐพ๊ธฐ
mask = (expert_idx == e)
if mask.any():
# ํด๋น ํ ํฐ๋ค ์ถ์ถ
selected = x[mask] # (num_selected, dim)
# ์ ๋ฌธ๊ฐ ์ ์ฉ
expert_output = self.experts[e](selected)
# ๊ฐ์ค์น ์ ์ฉํ์ฌ ๊ฒฐ๊ณผ์ ์ถ๊ฐ
output[mask] += expert_prob[mask].squeeze(-1).unsqueeze(-1) * expert_output
return output
# ์ฌ์ฉ ์์
moe = MoELayer(dim=4096, hidden_dim=14336, num_experts=8, top_k=2)
x = torch.randn(2, 10, 4096)
output = moe(x)
print(f"Output shape: {output.shape}") # (2, 10, 4096)
4. Mixtral 8x7B¶
4.1 ์ํคํ ์ฒ¶
Mixtral 8x7B๋ 8๊ฐ์ ์ ๋ฌธ๊ฐ๋ฅผ ๊ฐ์ง MoE ๋ชจ๋ธ๋ก, ๊ฐ ๋ ์ด์ด์์ 2๊ฐ์ ์ ๋ฌธ๊ฐ๋ง ํ์ฑํ๋ฉ๋๋ค.
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Mixtral 8x7B Architecture โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ Transformer Block โ โ
โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ โ
โ โ โ Attention (GQA) โ โ โ
โ โ โ โข 32 query heads, 8 KV heads โ โ โ
โ โ โ โข Sliding Window (4096) โ โ โ
โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ โ
โ โ โ โ โ
โ โ โผ โ โ
โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ โ
โ โ โ Sparse MoE FFN Layer โ โ โ
โ โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ โ โ
โ โ โ โ Router โ โ โ โ
โ โ โ โ (Select Top-2) โ โ โ โ
โ โ โ โโโโโโฌโโโโโฌโโโโโฌโโโโโฌโโโโโฌโโโโโฌโโโโโฌโโโโโฌโโ โ โ โ
โ โ โ โ โ โ โ โ โ โ โ โ โ โ
โ โ โ โผ โผ โผ โผ โผ โผ โผ โผ โ โ โ
โ โ โ [E1] [E2] [E3] [E4] [E5] [E6] [E7] [E8] โ โ โ
โ โ โ โ โ โ โ โ
โ โ โ ์ ํ ์ ํ ๋นํ์ฑ ๋นํ์ฑ ... โ โ โ
โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ โ
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ
โ ์ด ํ๋ผ๋ฏธํฐ: ~46.7B (8 experts ร 7B FFN params) โ
โ ํ์ฑ ํ๋ผ๋ฏธํฐ: ~12.9B (2/8 experts) โ
โ ์ถ๋ก ์๋: 12.9B dense ๋ชจ๋ธ๊ณผ ์ ์ฌ โ
โ ์ฑ๋ฅ: 70B dense ๋ชจ๋ธ ์์ค โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
4.2 Mixtral ์ฌ์¶
MIXTRAL_CONFIG = {
"dim": 4096,
"n_layers": 32,
"n_heads": 32,
"n_kv_heads": 8,
"head_dim": 128,
"hidden_dim": 14336,
"vocab_size": 32000,
# MoE ์ค์
"num_experts": 8,
"num_experts_per_tok": 2, # Top-K
# Attention
"sliding_window": 4096,
"context_length": 32768,
# ํ๋ผ๋ฏธํฐ ๊ณ์ฐ
# Attention: 4 ร dimยฒ ร n_layers = 4 ร 4096ยฒ ร 32 โ 2.1B
# MoE FFN: 8 ร 3 ร dim ร hidden ร n_layers = 8 ร 3 ร 4096 ร 14336 ร 32 โ 44.6B
# Total: ~46.7B
# Active: ~12.9B (attention + 2/8 FFN)
}
4.3 Load Balancing Loss¶
MoE์ ํต์ฌ ๊ณผ์ ์ค ํ๋๋ ์ ๋ฌธ๊ฐ ๋ถ๊ท ํ ๋ฌธ์ ์ ๋๋ค.
def load_balancing_loss(router_probs, expert_indices, num_experts):
"""
Load Balancing Loss: ์ ๋ฌธ๊ฐ๋ค์ด ๊ท ๋ฑํ๊ฒ ์ฌ์ฉ๋๋๋ก ์ ๋
๋ฌธ์ : ์ผ๋ถ ์ ๋ฌธ๊ฐ๋ง ๊ณผ๋ํ๊ฒ ์ฌ์ฉ๋๋ ํ์ (winner-take-all)
ํด๊ฒฐ: ๊ท ํ ์กํ ๋ผ์ฐํ
์ ์ ๋ํ๋ auxiliary loss
์์:
L_balance = ฮฑ ร ฮฃ_e (f_e ร P_e)
f_e = ์ ๋ฌธ๊ฐ e๊ฐ ์ ํ๋ ํ ํฐ ๋น์จ
P_e = ์ ๋ฌธ๊ฐ e์ ํ ๋น๋ ๋ผ์ฐํ
ํ๋ฅ ํ๊ท
ฮฑ = ์ค์ผ์ผ๋ง ๊ณ์ (์: 0.01)
"""
batch, seq_len, top_k = router_probs.shape
num_tokens = batch * seq_len
# f_e: ๊ฐ ์ ๋ฌธ๊ฐ๊ฐ ์ ํ๋ ๋น์จ
expert_counts = torch.zeros(num_experts, device=router_probs.device)
for e in range(num_experts):
expert_counts[e] = (expert_indices == e).float().sum() / (num_tokens * top_k)
# P_e: ๊ฐ ์ ๋ฌธ๊ฐ์ ํ ๋น๋ ํ๊ท ํ๋ฅ
expert_probs = torch.zeros(num_experts, device=router_probs.device)
# (๊ฐ์ํ๋ ๊ณ์ฐ - ์ค์ ๋ก๋ gate logits์์ ๊ณ์ฐ)
# Balance loss
loss = (expert_counts * expert_probs).sum() * num_experts
return loss
# ํ์ต ์
"""
total_loss = language_model_loss + alpha * load_balancing_loss
"""
5. MoE์ ์ฅ๋จ์ ¶
5.1 ์ฅ์ ¶
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ MoE์ ์ฅ์ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ 1. ํ๋ผ๋ฏธํฐ ํจ์จ์ฑ โ
โ โข ๋ง์ ํ๋ผ๋ฏธํฐ, ์ ์ ๊ณ์ฐ๋ โ
โ โข Mixtral 8x7B: 46.7B params, 12.9B active โ
โ โข Dense 70B ๊ธ ์ฑ๋ฅ, 13B ๊ธ ์๋ โ
โ โ
โ 2. ์ ๋ฌธํ (Specialization) โ
โ โข ๊ฐ ์ ๋ฌธ๊ฐ๊ฐ ๋ค๋ฅธ ํจํด/๋๋ฉ์ธ ํ์ต โ
โ โข ์: Expert 1=์ํ, Expert 2=์ฝ๋, Expert 3=์ธ์ด โ
โ โข ๋ ๊น์ ์ ๋ฌธ ์ง์ ์ธ์ฝ๋ฉ ๊ฐ๋ฅ โ
โ โ
โ 3. ์ค์ผ์ผ๋ง โ
โ โข ์ ๋ฌธ๊ฐ ์ ๋๋ ค ๋ชจ๋ธ ํ์ฅ ์ฉ์ด โ
โ โข ๊ณ์ฐ๋ ์ฆ๊ฐ ์ต์ํํ๋ฉฐ ์ฉ๋ ์ฆ๊ฐ โ
โ โข Google Switch Transformer: 1.6T params! โ
โ โ
โ 4. ํ์ต ํจ์จ โ
โ โข ๊ฐ์ ๊ณ์ฐ๋์ผ๋ก ๋ ํฐ ๋ชจ๋ธ ํ์ต ๊ฐ๋ฅ โ
โ โข Scaling Law ๊ด์ ์์ ์ ๋ฆฌ โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
5.2 ๋จ์ ¶
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ MoE์ ๋จ์ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
โ โ
โ 1. ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋ โ
โ โข ๋ชจ๋ ์ ๋ฌธ๊ฐ๋ฅผ ๋ฉ๋ชจ๋ฆฌ์ ๋ก๋ํด์ผ ํจ โ
โ โข Mixtral 8x7B: 46.7B params โ 93GB (FP16) โ
โ โข ์ถ๋ก ์ ๋ง์ GPU ๋ฉ๋ชจ๋ฆฌ ํ์ โ
โ โ
โ 2. ํ์ต ๋ถ์์ ์ฑ โ
โ โข ๋ผ์ฐํฐ ํ์ต์ด ์ด๋ ค์ โ
โ โข ์ ๋ฌธ๊ฐ ๋ถ๊ท ํ (์ผ๋ถ๋ง ์ฌ์ฉ) โ
โ โข Auxiliary loss ํ๋ ํ์ โ
โ โ
โ 3. ๋ถ์ฐ ํ์ต ๋ณต์ก์ฑ โ
โ โข Expert parallelism ํ์ โ
โ โข ํต์ ์ค๋ฒํค๋ โ
โ โข ๋ก๋ ๋ฐธ๋ฐ์ฑ ์ด๋ ค์ โ
โ โ
โ 4. Fine-tuning ์ด๋ ค์ โ
โ โข ์ ๋ฌธ๊ฐ specialization ์ ์งํ๋ฉฐ ์ ์ ํ์ โ
โ โข ์ผ๋ถ ์ ๋ฌธ๊ฐ๋ง fine-tune? โ
โ โข ์ฐ๊ตฌ ์งํ ์ค โ
โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
6. Mistral/Mixtral ์ค์ต¶
6.1 Mistral 7B ์ฌ์ฉ¶
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Mistral 7B ๋ก๋
model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto"
)
# ํ
์คํธ ์์ฑ
prompt = "[INST] Explain the concept of machine learning in simple terms. [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=200,
temperature=0.7,
do_sample=True,
top_p=0.9,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
6.2 Mixtral 8x7B ์ฌ์ฉ¶
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Mixtral 8x7B (๋ง์ ๋ฉ๋ชจ๋ฆฌ ํ์!)
model_name = "mistralai/Mixtral-8x7B-v0.1"
# 4-bit ์์ํ๋ก ๋ฉ๋ชจ๋ฆฌ ์ ์ฝ
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="auto"
)
# ์ฌ์ฉ
prompt = "[INST] Write a Python function to calculate fibonacci numbers. [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=300)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
6.3 vLLM์ผ๋ก ํจ์จ์ ์๋น¶
from vllm import LLM, SamplingParams
# vLLM์ MoE ๋ชจ๋ธ์ ํจ์จ์ ์ผ๋ก ์๋น
llm = LLM(
model="mistralai/Mixtral-8x7B-v0.1",
tensor_parallel_size=2, # 2 GPU
dtype="float16",
)
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.9,
max_tokens=200,
)
prompts = [
"[INST] What is machine learning? [/INST]",
"[INST] Explain quantum computing. [/INST]",
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print(f"Prompt: {output.prompt}")
print(f"Response: {output.outputs[0].text}")
print("-" * 50)
7. MoE ๋ณํ๋ค¶
7.1 ์ฃผ์ MoE ๋ชจ๋ธ๋ค¶
| ๋ชจ๋ธ | ์กฐ์ง | ์ ๋ฌธ๊ฐ ์ | Top-K | ์ด ํ๋ผ๋ฏธํฐ | ํ์ฑ ํ๋ผ๋ฏธํฐ |
|---|---|---|---|---|---|
| Switch Transformer | 2048 | 1 | 1.6T | <1B | |
| GLaM | 64 | 2 | 1.2T | ~100B | |
| Mixtral 8x7B | Mistral | 8 | 2 | 46.7B | 12.9B |
| Mixtral 8x22B | Mistral | 8 | 2 | 141B | 39B |
| DeepSeek MoE | DeepSeek | 64 | 6 | 145B | 22B |
7.2 Fine-grained MoE¶
"""
Fine-grained MoE: ๋ ๋ง์ ์์ ์ ๋ฌธ๊ฐ
๊ธฐ์กด (Coarse-grained):
- 8๊ฐ ํฐ ์ ๋ฌธ๊ฐ, Top-2 ์ ํ
- ๊ฐ ์ ๋ฌธ๊ฐ๊ฐ ๋์ ๋ฒ์ ๋ด๋น
Fine-grained (DeepSeek ์คํ์ผ):
- 64๊ฐ ์์ ์ ๋ฌธ๊ฐ, Top-6 ์ ํ
- ๋ ์ธ๋ฐํ ์ ๋ฌธํ ๊ฐ๋ฅ
- ๋ผ์ฐํ
์ ์ฐ์ฑ ์ฆ๊ฐ
์ฅ์ :
- ๋ ์ธ๋ฐํ ์ ๋ฌธํ
- ๋ ๋์ ๋ก๋ ๋ฐธ๋ฐ์ฑ
- ํ์ฅ์ฑ
๋จ์ :
- ๋ผ์ฐํ
์ค๋ฒํค๋
- ํ์ต ๋ณต์ก์ฑ
"""
์ ๋ฆฌ¶
Mistral ํต์ฌ¶
- Sliding Window Attention: ๋ฉ๋ชจ๋ฆฌ O(W)๋ก ๊ธด ์ํ์ค ์ฒ๋ฆฌ
- GQA: KV cache ํจ์จ์ฑ
- Over-training: ์์ ๋ชจ๋ธ, ๋ง์ ๋ฐ์ดํฐ
MoE ํต์ฌ¶
- Sparse Activation: ํ๋ผ๋ฏธํฐ ๅค, ๊ณ์ฐ ๅฐ
- Router: Top-K ์ ๋ฌธ๊ฐ ์ ํ
- Load Balancing: ์ ๋ฌธ๊ฐ ๊ท ํ ์ ์ง
์ค๋ฌด ์ ํ ๊ฐ์ด๋¶
| ์ํฉ | ๊ถ์ฅ ๋ชจ๋ธ |
|---|---|
| ๋จ์ผ GPU (16GB) | Mistral 7B (4-bit) |
| 2ร GPU (48GB) | Mixtral 8x7B (4-bit) |
| ์๋ฒ๊ธ (8ร A100) | Mixtral 8x22B |
| ์๋ ์ฐ์ | Mistral 7B |
| ์ฑ๋ฅ ์ฐ์ | Mixtral 8x7B+ |
๋ค์ ๋จ๊ณ¶
- 10_Long_Context_Models.md: ๊ธด ์ปจํ ์คํธ ์ฒ๋ฆฌ
- 22_Inference_Optimization.md: ํจ์จ์ ์ถ๋ก
์ฐธ๊ณ ์๋ฃ¶
ํต์ฌ ๋ ผ๋ฌธ¶
- Jiang et al. (2023). "Mistral 7B"
- Jiang et al. (2024). "Mixtral of Experts"
- Fedus et al. (2022). "Switch Transformers: Scaling to Trillion Parameter Models"
- Du et al. (2022). "GLaM: Efficient Scaling of Language Models"