Mistral & Mixture of Experts

ํ•™์Šต ๋ชฉํ‘œ

  • Mistral 7B์˜ ์•„ํ‚คํ…์ฒ˜ ํŠน์ง• ์ดํ•ด
  • Mixture of Experts (MoE) ๊ฐœ๋…๊ณผ ๋™์ž‘ ์›๋ฆฌ ํŒŒ์•…
  • Mixtral 8x7B ๊ตฌ์กฐ ํ•™์Šต
  • Sparse MoE์˜ ์žฅ๋‹จ์ ๊ณผ ์‹ค๋ฌด ํ™œ์šฉ๋ฒ• ์Šต๋“

1. Mistral 7B Overview

1.1 Mistral's Innovations

Mistral 7B๋Š” 2023๋…„ Mistral AI๊ฐ€ ๊ณต๊ฐœํ•œ ๋ชจ๋ธ๋กœ, 7B ํŒŒ๋ผ๋ฏธํ„ฐ๋กœ 13B ๊ธ‰ ์„ฑ๋Šฅ์„ ๋‹ฌ์„ฑํ–ˆ์Šต๋‹ˆ๋‹ค.

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                    Mistral 7B ํŠน์ง•                               โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                                 โ”‚
โ”‚  ์„ฑ๋Šฅ ๋น„๊ต (2023.10 ๊ธฐ์ค€):                                        โ”‚
โ”‚  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”      โ”‚
โ”‚  โ”‚  Model          โ”‚ Params โ”‚ MMLU  โ”‚ HellaSwag โ”‚ GSM8K โ”‚      โ”‚
โ”‚  โ”‚  โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”‚โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”‚โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”‚โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”‚โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”‚      โ”‚
โ”‚  โ”‚  LLaMA 2 7B     โ”‚ 7B     โ”‚ 45.3  โ”‚ 77.2      โ”‚ 14.6  โ”‚      โ”‚
โ”‚  โ”‚  LLaMA 2 13B    โ”‚ 13B    โ”‚ 54.8  โ”‚ 80.7      โ”‚ 28.7  โ”‚      โ”‚
โ”‚  โ”‚  Mistral 7B     โ”‚ 7B     โ”‚ 60.1  โ”‚ 81.3      โ”‚ 52.2  โ”‚ โ†!   โ”‚
โ”‚  โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜      โ”‚
โ”‚                                                                 โ”‚
โ”‚  ํ•ต์‹ฌ ๊ธฐ์ˆ :                                                       โ”‚
โ”‚  โ€ข Sliding Window Attention (SWA)                               โ”‚
โ”‚  โ€ข Grouped Query Attention (GQA)                                โ”‚
โ”‚  โ€ข ๋” ๋งŽ์€ ๋ฐ์ดํ„ฐ๋กœ Over-training                                 โ”‚
โ”‚  โ€ข Flash Attention 2 ์ตœ์ ํ™”                                      โ”‚
โ”‚                                                                 โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

1.2 Mistral Architecture Specs

MISTRAL_CONFIGS = {
    "mistral-7b": {
        "dim": 4096,
        "n_layers": 32,
        "n_heads": 32,
        "n_kv_heads": 8,           # GQA
        "head_dim": 128,
        "hidden_dim": 14336,
        "vocab_size": 32000,
        "context_length": 32768,   # ๊ธฐ์ˆ ์  ํ•œ๊ณ„
        "sliding_window": 4096,    # Sliding Window Attention
        "rope_theta": 10000.0,
    },
}

# Comparison with LLaMA 2 7B
LLAMA2_7B = {
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 32,              # MHA (GQA ๋ฏธ์‚ฌ์šฉ)
    "hidden_dim": 11008,
    "context_length": 4096,
    "sliding_window": None,        # ์ „์ฒด attention
}
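
As a quick sanity check on these configs, the sketch below estimates the total parameter count from the hyperparameters. This is a back-of-the-envelope helper of our own (it ignores norm weights), not code from the Mistral release.

def estimate_params(cfg):
    """Rough parameter estimate: embeddings + attention + SwiGLU FFN."""
    dim, n_layers = cfg["dim"], cfg["n_layers"]
    kv_dim = cfg["n_kv_heads"] * cfg["head_dim"]
    embed = 2 * cfg["vocab_size"] * dim                   # input embedding + LM head
    attn = n_layers * (2 * dim * dim + 2 * dim * kv_dim)  # Wq, Wo: dim×dim; Wk, Wv: dim×kv_dim
    ffn = n_layers * 3 * dim * cfg["hidden_dim"]          # SwiGLU: gate, up, down projections
    return embed + attn + ffn

print(f"{estimate_params(MISTRAL_CONFIGS['mistral-7b']) / 1e9:.2f}B")
# 7.24B — close to the advertised 7B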

2. Sliding Window Attention (SWA)

2.1 Concept

Sliding Window Attention restricts each token to attend only to tokens within a fixed-size window.

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                    Sliding Window Attention                      โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                                 โ”‚
โ”‚  Full Attention (๊ธฐ์กด):                                          โ”‚
โ”‚  โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€                                       โ”‚
โ”‚  ๋ชจ๋“  ํ† ํฐ์ด ๋ชจ๋“  ์ด์ „ ํ† ํฐ์— attend                               โ”‚
โ”‚  ๋ณต์žก๋„: O(nยฒ)                                                   โ”‚
โ”‚                                                                 โ”‚
โ”‚  Position:  1  2  3  4  5  6  7  8  9  10                       โ”‚
โ”‚  Token 10:  โœ“  โœ“  โœ“  โœ“  โœ“  โœ“  โœ“  โœ“  โœ“  โœ“                       โ”‚
โ”‚                                                                 โ”‚
โ”‚  Sliding Window (W=4):                                          โ”‚
โ”‚  โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€                                       โ”‚
โ”‚  ์œˆ๋„์šฐ ํฌ๊ธฐ W ๋‚ด์˜ ํ† ํฐ๋งŒ attend                                  โ”‚
โ”‚  ๋ณต์žก๋„: O(n ร— W)                                                โ”‚
โ”‚                                                                 โ”‚
โ”‚  Position:  1  2  3  4  5  6  7  8  9  10                       โ”‚
โ”‚  Token 10:  โœ—  โœ—  โœ—  โœ—  โœ—  โœ—  โœ“  โœ“  โœ“  โœ“                       โ”‚
โ”‚                         โ†‘     โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜                 โ”‚
โ”‚                    Window start       Window (W=4)              โ”‚
โ”‚                                                                 โ”‚
โ”‚  ๋ ˆ์ด์–ด ์Œ“๊ธฐ ํšจ๊ณผ:                                                โ”‚
โ”‚  โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€                                       โ”‚
โ”‚  L๊ฐœ ๋ ˆ์ด์–ด โ†’ ์‹ค์ œ receptive field = L ร— W                       โ”‚
โ”‚  32 layers ร— 4096 window = 131,072 ํ† ํฐ ๋ฒ”์œ„!                    โ”‚
โ”‚                                                                 โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

2.2 Implementation

import torch
import torch.nn.functional as F
import math

def sliding_window_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    window_size: int = 4096,
    causal: bool = True,
):
    """
    Sliding Window Attention ๊ตฌํ˜„

    Args:
        query: (batch, n_heads, seq_len, head_dim)
        key: (batch, n_heads, seq_len, head_dim)
        value: (batch, n_heads, seq_len, head_dim)
        window_size: ์œˆ๋„์šฐ ํฌ๊ธฐ
        causal: Causal masking ์ ์šฉ ์—ฌ๋ถ€
    """
    batch, n_heads, seq_len, head_dim = query.shape
    scale = 1.0 / math.sqrt(head_dim)

    # Attention scores
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale

    # Build the sliding window mask
    # Position i attends only to positions max(0, i - W + 1) .. i
    row_idx = torch.arange(seq_len).unsqueeze(1)  # (seq, 1)
    col_idx = torch.arange(seq_len).unsqueeze(0)  # (1, seq)

    # Causal: col <= row
    # Window: col >= row - window_size + 1
    if causal:
        mask = (col_idx <= row_idx) & (col_idx >= row_idx - window_size + 1)
    else:
        mask = torch.abs(row_idx - col_idx) < window_size

    # Apply mask
    mask = mask.to(scores.device)
    scores = scores.masked_fill(~mask, float('-inf'))

    # Softmax & output
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, value)

    return output
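
A quick shape check of the function above, using toy sizes of our own choosing:

q = k = v = torch.randn(1, 8, 16, 64)  # batch=1, heads=8, seq=16, head_dim=64
out = sliding_window_attention(q, k, v, window_size=4)
print(out.shape)  # torch.Size([1, 8, 16, 64])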

# ๋ฉ”๋ชจ๋ฆฌ ๋น„๊ต
def compare_attention_memory(seq_len, window_size=4096):
    """Full vs Sliding Window ๋ฉ”๋ชจ๋ฆฌ ๋น„๊ต"""
    full_attention_mem = seq_len * seq_len  # O(nยฒ)
    sliding_window_mem = seq_len * window_size  # O(n ร— W)

    print(f"Sequence length: {seq_len:,}")
    print(f"Full Attention: {full_attention_mem:,} elements")
    print(f"Sliding Window: {sliding_window_mem:,} elements")
    print(f"Memory savings: {(1 - sliding_window_mem/full_attention_mem)*100:.1f}%")

compare_attention_memory(32768, 4096)
# Sequence length: 32,768
# Full Attention: 1,073,741,824 elements
# Sliding Window: 134,217,728 elements
# Memory savings: 87.5%

2.3 Rolling Buffer KV Cache

"""
Rolling Buffer: ๊ณ ์ • ํฌ๊ธฐ KV cache๋กœ ๊ธด ์‹œํ€€์Šค ์ฒ˜๋ฆฌ

์ผ๋ฐ˜ KV Cache:
- ๋ชจ๋“  ํ† ํฐ์˜ KV ์ €์žฅ
- ๋ฉ”๋ชจ๋ฆฌ: O(seq_len)

Rolling Buffer:
- window_size๋งŒํผ๋งŒ ์ €์žฅ
- ์˜ค๋ž˜๋œ KV๋Š” ๋ฎ์–ด์”€
- ๋ฉ”๋ชจ๋ฆฌ: O(window_size) = ์ƒ์ˆ˜!

์˜ˆ์‹œ (window=4):
Step 1: [K1, K2, K3, K4]
Step 2: [K5, K2, K3, K4]  โ† K1 ์œ„์น˜์— K5 ์ €์žฅ
Step 3: [K5, K6, K3, K4]  โ† K2 ์œ„์น˜์— K6 ์ €์žฅ
...

์žฅ์ :
- ๋ฌดํ•œ ์‹œํ€€์Šค ์ฒ˜๋ฆฌ ๊ฐ€๋Šฅ (๋ฉ”๋ชจ๋ฆฌ ๊ณ ์ •)
- ์ถ”๋ก  ์†๋„ ์ผ์ •

๋‹จ์ :
- ์˜ค๋ž˜๋œ ์ •๋ณด ์†์‹ค
- ๋ ˆ์ด์–ด ์Œ“๊ธฐ๋กœ ๋ณด์™„
"""

class RollingKVCache:
    def __init__(self, window_size: int, n_layers: int, n_kv_heads: int, head_dim: int):
        self.window_size = window_size
        self.cache_k = torch.zeros(n_layers, 1, window_size, n_kv_heads, head_dim)
        self.cache_v = torch.zeros(n_layers, 1, window_size, n_kv_heads, head_dim)
        self.pos = 0

    def update(self, layer_idx: int, k: torch.Tensor, v: torch.Tensor):
        """์ƒˆ๋กœ์šด KV๋ฅผ cache์— ์ถ”๊ฐ€ (circular buffer)"""
        seq_len = k.shape[1]
        for i in range(seq_len):
            idx = (self.pos + i) % self.window_size
            self.cache_k[layer_idx, :, idx] = k[:, i]
            self.cache_v[layer_idx, :, idx] = v[:, i]
        self.pos = (self.pos + seq_len) % self.window_size

    def get(self, layer_idx: int):
        return self.cache_k[layer_idx], self.cache_v[layer_idx]
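
A toy usage sketch (our own example values) showing the wrap-around behavior with a tiny window:

cache = RollingKVCache(window_size=4, n_layers=1, n_kv_heads=2, head_dim=8)
k = torch.randn(1, 6, 2, 8)  # 6 new tokens; the last 2 wrap around
v = torch.randn(1, 6, 2, 8)
cache.update(layer_idx=0, k=k, v=v)
print(cache.pos)  # 2 — tokens 5 and 6 overwrote slots 0 and 1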

3. Mixture of Experts (MoE) Basics

3.1 The MoE Concept

Mixture of Experts is an architecture that improves efficiency by activating only a subset of several "expert" networks.

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                    Mixture of Experts ๊ฐœ๋…                       โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                                 โ”‚
โ”‚  Dense Model:                                                   โ”‚
โ”‚  โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€                                              โ”‚
โ”‚  Input โ”€โ”€โ–บ [FFN (์ „์ฒด)] โ”€โ”€โ–บ Output                              โ”‚
โ”‚  โ€ข ๋ชจ๋“  ํŒŒ๋ผ๋ฏธํ„ฐ๊ฐ€ ๋งค๋ฒˆ ํ™œ์„ฑํ™”                                     โ”‚
โ”‚  โ€ข ๊ณ„์‚ฐ๋Ÿ‰ = ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜์— ๋น„๋ก€                                    โ”‚
โ”‚                                                                 โ”‚
โ”‚  Sparse MoE Model:                                              โ”‚
โ”‚  โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€                                              โ”‚
โ”‚                        โ”Œโ”€โ”€โ–บ Expert 1 โ”€โ”€โ”                        โ”‚
โ”‚                        โ”‚               โ”‚                        โ”‚
โ”‚  Input โ”€โ”€โ–บ Router โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ–บ Expert 2 โ”€โ”€โ”ผโ”€โ”€โ–บ Combine โ”€โ”€โ–บ Output  โ”‚
โ”‚              โ†“         โ”‚               โ”‚                        โ”‚
โ”‚         (Top-K ์„ ํƒ)   โ””โ”€โ”€โ–บ Expert 3 โ”€โ”€โ”˜                        โ”‚
โ”‚                        โ””โ”€โ”€โ–บ Expert N (๋น„ํ™œ์„ฑํ™”)                   โ”‚
โ”‚                                                                 โ”‚
โ”‚  โ€ข ๋ผ์šฐํ„ฐ๊ฐ€ K๊ฐœ ์ „๋ฌธ๊ฐ€๋งŒ ์„ ํƒ                                      โ”‚
โ”‚  โ€ข ํŒŒ๋ผ๋ฏธํ„ฐ ๅคš, ๊ณ„์‚ฐ๋Ÿ‰ ๅฐ‘                                         โ”‚
โ”‚  โ€ข ์˜ˆ: 8๊ฐœ ์ „๋ฌธ๊ฐ€, 2๊ฐœ๋งŒ ํ™œ์„ฑํ™” โ†’ ๊ณ„์‚ฐ๋Ÿ‰ 1/4                       โ”‚
โ”‚                                                                 โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

3.2 Router (Gating Network)

import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKRouter(nn.Module):
    """
    Top-K Router: ์ž…๋ ฅ๋งˆ๋‹ค K๊ฐœ์˜ ์ „๋ฌธ๊ฐ€ ์„ ํƒ

    ์ˆ˜์‹:
    G(x) = softmax(TopK(x ยท W_g))

    ์—ฌ๊ธฐ์„œ TopK๋Š” ์ƒ์œ„ K๊ฐœ๋งŒ ์œ ์ง€, ๋‚˜๋จธ์ง€๋Š” -inf
    """
    def __init__(self, dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        self.gate = nn.Linear(dim, num_experts, bias=False)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, dim)

        Returns:
            router_probs: (batch, seq_len, top_k) - weights of the selected experts
            expert_indices: (batch, seq_len, top_k) - indices of the selected experts
        """
        # Compute router logits
        logits = self.gate(x)  # (batch, seq_len, num_experts)

        # Select the top-K experts
        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)

        # Softmax over just the selected experts
        router_probs = F.softmax(top_k_logits, dim=-1)

        return router_probs, top_k_indices

# ์˜ˆ์‹œ
router = TopKRouter(dim=4096, num_experts=8, top_k=2)
x = torch.randn(2, 10, 4096)  # batch=2, seq=10
probs, indices = router(x)
print(f"Router probs shape: {probs.shape}")    # (2, 10, 2)
print(f"Expert indices shape: {indices.shape}")  # (2, 10, 2)
print(f"Selected experts for token 0: {indices[0, 0]}")  # e.g., tensor([3, 7])

3.3 Expert Layer

class MoELayer(nn.Module):
    """
    Mixture of Experts Layer

    ๊ฐ ํ† ํฐ์ด Top-K ์ „๋ฌธ๊ฐ€์—๊ฒŒ ๋ผ์šฐํŒ…๋˜์–ด ์ฒ˜๋ฆฌ๋จ
    """
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_experts: int = 8,
        top_k: int = 2,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Router
        self.router = TopKRouter(dim, num_experts, top_k)

        # Experts (each an independent FFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim, hidden_dim, bias=False),
                nn.SiLU(),
                nn.Linear(hidden_dim, dim, bias=False)
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, dim)

        Returns:
            output: (batch, seq_len, dim)
        """
        batch, seq_len, dim = x.shape

        # ๋ผ์šฐํŒ…
        router_probs, expert_indices = self.router(x)
        # router_probs: (batch, seq_len, top_k)
        # expert_indices: (batch, seq_len, top_k)

        # ์ถœ๋ ฅ ์ดˆ๊ธฐํ™”
        output = torch.zeros_like(x)

        # ๊ฐ ์ „๋ฌธ๊ฐ€๋ณ„๋กœ ์ฒ˜๋ฆฌ (๊ฐ„๋‹จํ•œ ๊ตฌํ˜„, ์‹ค์ œ๋กœ๋Š” ๋” ์ตœ์ ํ™”๋จ)
        for k in range(self.top_k):
            expert_idx = expert_indices[:, :, k]  # (batch, seq_len)
            expert_prob = router_probs[:, :, k:k+1]  # (batch, seq_len, 1)

            for e in range(self.num_experts):
                # Find the positions where this expert was selected
                mask = (expert_idx == e)
                if mask.any():
                    # Gather those tokens
                    selected = x[mask]  # (num_selected, dim)
                    # Apply the expert
                    expert_output = self.experts[e](selected)
                    # Scale by the router weight and accumulate
                    output[mask] += expert_prob[mask] * expert_output

        return output

# ์‚ฌ์šฉ ์˜ˆ์‹œ
moe = MoELayer(dim=4096, hidden_dim=14336, num_experts=8, top_k=2)
x = torch.randn(2, 10, 4096)
output = moe(x)
print(f"Output shape: {output.shape}")  # (2, 10, 4096)

4. Mixtral 8x7B

4.1 ์•„ํ‚คํ…์ฒ˜

Mixtral 8x7B is an MoE model with 8 experts per layer, of which only 2 are activated for each token.

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                    Mixtral 8x7B Architecture                     โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                                 โ”‚
โ”‚  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”    โ”‚
โ”‚  โ”‚                   Transformer Block                      โ”‚    โ”‚
โ”‚  โ”‚  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”    โ”‚    โ”‚
โ”‚  โ”‚  โ”‚              Attention (GQA)                     โ”‚    โ”‚    โ”‚
โ”‚  โ”‚  โ”‚  โ€ข 32 query heads, 8 KV heads                   โ”‚    โ”‚    โ”‚
โ”‚  โ”‚  โ”‚  โ€ข Sliding Window (4096)                        โ”‚    โ”‚    โ”‚
โ”‚  โ”‚  โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜    โ”‚    โ”‚
โ”‚  โ”‚                        โ”‚                                โ”‚    โ”‚
โ”‚  โ”‚                        โ–ผ                                โ”‚    โ”‚
โ”‚  โ”‚  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”    โ”‚    โ”‚
โ”‚  โ”‚  โ”‚         Sparse MoE FFN Layer                    โ”‚    โ”‚    โ”‚
โ”‚  โ”‚  โ”‚  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”    โ”‚    โ”‚    โ”‚
โ”‚  โ”‚  โ”‚  โ”‚              Router                      โ”‚    โ”‚    โ”‚    โ”‚
โ”‚  โ”‚  โ”‚  โ”‚         (Select Top-2)                   โ”‚    โ”‚    โ”‚    โ”‚
โ”‚  โ”‚  โ”‚  โ””โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”ฌโ”€โ”˜    โ”‚    โ”‚    โ”‚
โ”‚  โ”‚  โ”‚       โ”‚    โ”‚    โ”‚    โ”‚    โ”‚    โ”‚    โ”‚    โ”‚      โ”‚    โ”‚    โ”‚
โ”‚  โ”‚  โ”‚       โ–ผ    โ–ผ    โ–ผ    โ–ผ    โ–ผ    โ–ผ    โ–ผ    โ–ผ      โ”‚    โ”‚    โ”‚
โ”‚  โ”‚  โ”‚     [E1] [E2] [E3] [E4] [E5] [E6] [E7] [E8]     โ”‚    โ”‚    โ”‚
โ”‚  โ”‚  โ”‚      โœ“         โœ“                               โ”‚    โ”‚    โ”‚
โ”‚  โ”‚  โ”‚     ์„ ํƒ      ์„ ํƒ    ๋น„ํ™œ์„ฑ   ๋น„ํ™œ์„ฑ   ...        โ”‚    โ”‚    โ”‚
โ”‚  โ”‚  โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜    โ”‚    โ”‚
โ”‚  โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜    โ”‚
โ”‚                                                                 โ”‚
โ”‚  ์ด ํŒŒ๋ผ๋ฏธํ„ฐ: ~46.7B (8 experts ร— 7B FFN params)                 โ”‚
โ”‚  ํ™œ์„ฑ ํŒŒ๋ผ๋ฏธํ„ฐ: ~12.9B (2/8 experts)                              โ”‚
โ”‚  ์ถ”๋ก  ์†๋„: 12.9B dense ๋ชจ๋ธ๊ณผ ์œ ์‚ฌ                               โ”‚
โ”‚  ์„ฑ๋Šฅ: 70B dense ๋ชจ๋ธ ์ˆ˜์ค€                                       โ”‚
โ”‚                                                                 โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

4.2 Mixtral Specs

MIXTRAL_CONFIG = {
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,
    "head_dim": 128,
    "hidden_dim": 14336,
    "vocab_size": 32000,

    # MoE settings
    "num_experts": 8,
    "num_experts_per_tok": 2,  # Top-K

    # Attention
    "sliding_window": 4096,
    "context_length": 32768,

    # ํŒŒ๋ผ๋ฏธํ„ฐ ๊ณ„์‚ฐ
    # Attention: 4 ร— dimยฒ ร— n_layers = 4 ร— 4096ยฒ ร— 32 โ‰ˆ 2.1B
    # MoE FFN: 8 ร— 3 ร— dim ร— hidden ร— n_layers = 8 ร— 3 ร— 4096 ร— 14336 ร— 32 โ‰ˆ 44.6B
    # Total: ~46.7B
    # Active: ~12.9B (attention + 2/8 FFN)
}
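
The parameter math above can be checked numerically; this is a quick sketch of our own using the config dict:

cfg = MIXTRAL_CONFIG
kv_dim = cfg["n_kv_heads"] * cfg["head_dim"]
attn = cfg["n_layers"] * (2 * cfg["dim"] ** 2 + 2 * cfg["dim"] * kv_dim)
ffn_all = cfg["n_layers"] * cfg["num_experts"] * 3 * cfg["dim"] * cfg["hidden_dim"]
embed = 2 * cfg["vocab_size"] * cfg["dim"]
active_ffn = ffn_all * cfg["num_experts_per_tok"] // cfg["num_experts"]
print(f"Total: {(attn + ffn_all + embed) / 1e9:.1f}B")      # Total: 46.7B
print(f"Active: {(attn + active_ffn + embed) / 1e9:.1f}B")  # Active: 12.9B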

4.3 Load Balancing Loss

MoE์˜ ํ•ต์‹ฌ ๊ณผ์ œ ์ค‘ ํ•˜๋‚˜๋Š” ์ „๋ฌธ๊ฐ€ ๋ถˆ๊ท ํ˜• ๋ฌธ์ œ์ž…๋‹ˆ๋‹ค.

def load_balancing_loss(router_logits, expert_indices, num_experts, top_k):
    """
    Load balancing loss: pushes the router toward uniform expert usage.

    Problem: a few experts receive most tokens (winner-take-all).
    Fix: an auxiliary loss that is minimized by balanced routing.

    Formula (Switch Transformer style):
    L_balance = N × Σ_e (f_e × P_e)

    f_e = fraction of routing slots assigned to expert e
    P_e = mean routing probability assigned to expert e
    The result is scaled by α (e.g., 0.01) when added to the LM loss.

    Args:
        router_logits: (batch, seq_len, num_experts) - raw gate logits
        expert_indices: (batch, seq_len, top_k) - selected expert indices
    """
    batch, seq_len, _ = router_logits.shape
    num_tokens = batch * seq_len

    # P_e: mean routing probability per expert, from the full softmax
    probs = F.softmax(router_logits, dim=-1)           # (batch, seq, num_experts)
    expert_probs = probs.mean(dim=(0, 1))              # (num_experts,)

    # f_e: fraction of routing slots each expert actually received
    one_hot = F.one_hot(expert_indices, num_experts)   # (batch, seq, top_k, num_experts)
    expert_counts = one_hot.float().sum(dim=(0, 1, 2)) / (num_tokens * top_k)

    # Balance loss (≈ 1.0 when routing is uniform)
    return num_experts * (expert_counts * expert_probs).sum()

# ํ•™์Šต ์‹œ
"""
total_loss = language_model_loss + alpha * load_balancing_loss
"""

5. MoE์˜ ์žฅ๋‹จ์ 

5.1 Advantages

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                    MoE์˜ ์žฅ์                                      โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                                 โ”‚
โ”‚  1. ํŒŒ๋ผ๋ฏธํ„ฐ ํšจ์œจ์„ฑ                                               โ”‚
โ”‚     โ€ข ๋งŽ์€ ํŒŒ๋ผ๋ฏธํ„ฐ, ์ ์€ ๊ณ„์‚ฐ๋Ÿ‰                                  โ”‚
โ”‚     โ€ข Mixtral 8x7B: 46.7B params, 12.9B active                   โ”‚
โ”‚     โ€ข Dense 70B ๊ธ‰ ์„ฑ๋Šฅ, 13B ๊ธ‰ ์†๋„                             โ”‚
โ”‚                                                                 โ”‚
โ”‚  2. ์ „๋ฌธํ™” (Specialization)                                      โ”‚
โ”‚     โ€ข ๊ฐ ์ „๋ฌธ๊ฐ€๊ฐ€ ๋‹ค๋ฅธ ํŒจํ„ด/๋„๋ฉ”์ธ ํ•™์Šต                            โ”‚
โ”‚     โ€ข ์˜ˆ: Expert 1=์ˆ˜ํ•™, Expert 2=์ฝ”๋“œ, Expert 3=์–ธ์–ด             โ”‚
โ”‚     โ€ข ๋” ๊นŠ์€ ์ „๋ฌธ ์ง€์‹ ์ธ์ฝ”๋”ฉ ๊ฐ€๋Šฅ                                โ”‚
โ”‚                                                                 โ”‚
โ”‚  3. ์Šค์ผ€์ผ๋ง                                                      โ”‚
โ”‚     โ€ข ์ „๋ฌธ๊ฐ€ ์ˆ˜ ๋Š˜๋ ค ๋ชจ๋ธ ํ™•์žฅ ์šฉ์ด                                โ”‚
โ”‚     โ€ข ๊ณ„์‚ฐ๋Ÿ‰ ์ฆ๊ฐ€ ์ตœ์†Œํ™”ํ•˜๋ฉฐ ์šฉ๋Ÿ‰ ์ฆ๊ฐ€                             โ”‚
โ”‚     โ€ข Google Switch Transformer: 1.6T params!                    โ”‚
โ”‚                                                                 โ”‚
โ”‚  4. ํ•™์Šต ํšจ์œจ                                                     โ”‚
โ”‚     โ€ข ๊ฐ™์€ ๊ณ„์‚ฐ๋Ÿ‰์œผ๋กœ ๋” ํฐ ๋ชจ๋ธ ํ•™์Šต ๊ฐ€๋Šฅ                          โ”‚
โ”‚     โ€ข Scaling Law ๊ด€์ ์—์„œ ์œ ๋ฆฌ                                   โ”‚
โ”‚                                                                 โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

5.2 Disadvantages

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                    MoE์˜ ๋‹จ์                                      โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                                 โ”‚
โ”‚  1. ๋ฉ”๋ชจ๋ฆฌ ์š”๊ตฌ๋Ÿ‰                                                 โ”‚
โ”‚     โ€ข ๋ชจ๋“  ์ „๋ฌธ๊ฐ€๋ฅผ ๋ฉ”๋ชจ๋ฆฌ์— ๋กœ๋“œํ•ด์•ผ ํ•จ                           โ”‚
โ”‚     โ€ข Mixtral 8x7B: 46.7B params โ‰ˆ 93GB (FP16)                   โ”‚
โ”‚     โ€ข ์ถ”๋ก  ์‹œ ๋งŽ์€ GPU ๋ฉ”๋ชจ๋ฆฌ ํ•„์š”                                 โ”‚
โ”‚                                                                 โ”‚
โ”‚  2. ํ•™์Šต ๋ถˆ์•ˆ์ •์„ฑ                                                 โ”‚
โ”‚     โ€ข ๋ผ์šฐํ„ฐ ํ•™์Šต์ด ์–ด๋ ค์›€                                        โ”‚
โ”‚     โ€ข ์ „๋ฌธ๊ฐ€ ๋ถˆ๊ท ํ˜• (์ผ๋ถ€๋งŒ ์‚ฌ์šฉ)                                  โ”‚
โ”‚     โ€ข Auxiliary loss ํŠœ๋‹ ํ•„์š”                                   โ”‚
โ”‚                                                                 โ”‚
โ”‚  3. ๋ถ„์‚ฐ ํ•™์Šต ๋ณต์žก์„ฑ                                              โ”‚
โ”‚     โ€ข Expert parallelism ํ•„์š”                                    โ”‚
โ”‚     โ€ข ํ†ต์‹  ์˜ค๋ฒ„ํ—ค๋“œ                                               โ”‚
โ”‚     โ€ข ๋กœ๋“œ ๋ฐธ๋Ÿฐ์‹ฑ ์–ด๋ ค์›€                                          โ”‚
โ”‚                                                                 โ”‚
โ”‚  4. Fine-tuning ์–ด๋ ค์›€                                           โ”‚
โ”‚     โ€ข ์ „๋ฌธ๊ฐ€ specialization ์œ ์ง€ํ•˜๋ฉฐ ์ ์‘ ํ•„์š”                    โ”‚
โ”‚     โ€ข ์ผ๋ถ€ ์ „๋ฌธ๊ฐ€๋งŒ fine-tune?                                    โ”‚
โ”‚     โ€ข ์—ฐ๊ตฌ ์ง„ํ–‰ ์ค‘                                                โ”‚
โ”‚                                                                 โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

6. Hands-On: Mistral/Mixtral

6.1 Using Mistral 7B

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load Mistral 7B (the Instruct variant, to match the [INST] prompt format below)
model_name = "mistralai/Mistral-7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

# ํ…์ŠคํŠธ ์ƒ์„ฑ
prompt = "[INST] Explain the concept of machine learning in simple terms. [/INST]"

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=200,
    temperature=0.7,
    do_sample=True,
    top_p=0.9,
)

response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

6.2 Using Mixtral 8x7B

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Mixtral 8x7B (needs a lot of memory!) — Instruct variant to match the [INST] prompt
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Save memory with 4-bit quantization
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto"
)

# ์‚ฌ์šฉ
prompt = "[INST] Write a Python function to calculate fibonacci numbers. [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=300)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

6.3 Efficient Serving with vLLM

from vllm import LLM, SamplingParams

# vLLM์€ MoE ๋ชจ๋ธ์„ ํšจ์œจ์ ์œผ๋กœ ์„œ๋น™
llm = LLM(
    model="mistralai/Mixtral-8x7B-v0.1",
    tensor_parallel_size=2,  # 2 GPU
    dtype="float16",
)

sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=200,
)

prompts = [
    "[INST] What is machine learning? [/INST]",
    "[INST] Explain quantum computing. [/INST]",
]

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(f"Prompt: {output.prompt}")
    print(f"Response: {output.outputs[0].text}")
    print("-" * 50)

7. MoE Variants

7.1 ์ฃผ์š” MoE ๋ชจ๋ธ๋“ค

๋ชจ๋ธ ์กฐ์ง ์ „๋ฌธ๊ฐ€ ์ˆ˜ Top-K ์ด ํŒŒ๋ผ๋ฏธํ„ฐ ํ™œ์„ฑ ํŒŒ๋ผ๋ฏธํ„ฐ
Switch Transformer Google 2048 1 1.6T <1B
GLaM Google 64 2 1.2T ~100B
Mixtral 8x7B Mistral 8 2 46.7B 12.9B
Mixtral 8x22B Mistral 8 2 141B 39B
DeepSeek MoE DeepSeek 64 6 145B 22B

7.2 Fine-grained MoE

"""
Fine-grained MoE: ๋” ๋งŽ์€ ์ž‘์€ ์ „๋ฌธ๊ฐ€

๊ธฐ์กด (Coarse-grained):
- 8๊ฐœ ํฐ ์ „๋ฌธ๊ฐ€, Top-2 ์„ ํƒ
- ๊ฐ ์ „๋ฌธ๊ฐ€๊ฐ€ ๋„“์€ ๋ฒ”์œ„ ๋‹ด๋‹น

Fine-grained (DeepSeek ์Šคํƒ€์ผ):
- 64๊ฐœ ์ž‘์€ ์ „๋ฌธ๊ฐ€, Top-6 ์„ ํƒ
- ๋” ์„ธ๋ฐ€ํ•œ ์ „๋ฌธํ™” ๊ฐ€๋Šฅ
- ๋ผ์šฐํŒ… ์œ ์—ฐ์„ฑ ์ฆ๊ฐ€

์žฅ์ :
- ๋” ์„ธ๋ฐ€ํ•œ ์ „๋ฌธํ™”
- ๋” ๋‚˜์€ ๋กœ๋“œ ๋ฐธ๋Ÿฐ์‹ฑ
- ํ™•์žฅ์„ฑ

๋‹จ์ :
- ๋ผ์šฐํŒ… ์˜ค๋ฒ„ํ—ค๋“œ
- ํ•™์Šต ๋ณต์žก์„ฑ
"""

์ •๋ฆฌ

Mistral ํ•ต์‹ฌ

  • Sliding Window Attention: ๋ฉ”๋ชจ๋ฆฌ O(W)๋กœ ๊ธด ์‹œํ€€์Šค ์ฒ˜๋ฆฌ
  • GQA: KV cache ํšจ์œจ์„ฑ
  • Over-training: ์ž‘์€ ๋ชจ๋ธ, ๋งŽ์€ ๋ฐ์ดํ„ฐ

MoE ํ•ต์‹ฌ

  • Sparse Activation: ํŒŒ๋ผ๋ฏธํ„ฐ ๅคš, ๊ณ„์‚ฐ ๅฐ‘
  • Router: Top-K ์ „๋ฌธ๊ฐ€ ์„ ํƒ
  • Load Balancing: ์ „๋ฌธ๊ฐ€ ๊ท ํ˜• ์œ ์ง€

์‹ค๋ฌด ์„ ํƒ ๊ฐ€์ด๋“œ

์ƒํ™ฉ ๊ถŒ์žฅ ๋ชจ๋ธ
๋‹จ์ผ GPU (16GB) Mistral 7B (4-bit)
2ร— GPU (48GB) Mixtral 8x7B (4-bit)
์„œ๋ฒ„๊ธ‰ (8ร— A100) Mixtral 8x22B
์†๋„ ์šฐ์„  Mistral 7B
์„ฑ๋Šฅ ์šฐ์„  Mixtral 8x7B+

๋‹ค์Œ ๋‹จ๊ณ„


์ฐธ๊ณ  ์ž๋ฃŒ

ํ•ต์‹ฌ ๋…ผ๋ฌธ

  • Jiang et al. (2023). "Mistral 7B"
  • Jiang et al. (2024). "Mixtral of Experts"
  • Fedus et al. (2022). "Switch Transformers: Scaling to Trillion Parameter Models"
  • Du et al. (2022). "GLaM: Efficient Scaling of Language Models"

์ฝ”๋“œ & ์ž๋ฃŒ

to navigate between lessons