LLaMA Family

Learning Objectives

  • Understand the architectural evolution of LLaMA 1/2/3
  • Learn the core techniques: RoPE, RMSNorm, and SwiGLU
  • Understand the Grouped Query Attention (GQA) mechanism
  • Learn how to use LLaMA in practice

1. LLaMA Overview

1.1 Why LLaMA Matters

LLaMA (Large Language Model Meta AI) is an open-source LLM family released by Meta in 2023 that helped democratize foundation-model research.

┌─────────────────────────────────────────────────────────────────┐
│                Historical Significance of LLaMA                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Before LLaMA (2022):                                           │
│  • Best models were API-only (GPT-3.5, PaLM)                    │
│  • Models open to academic research lagged in quality           │
│  • Open-source community had limited LLM access                 │
│                                                                 │
│  After LLaMA (2023):                                            │
│  • Researchers could study frontier models directly             │
│  • Explosion of derivatives (Alpaca, Vicuna, ...)               │
│  • Pace of LLM research accelerated sharply                     │
│                                                                 │
│  Key contributions:                                             │
│  • Applied the Chinchilla rule (D = 20N and beyond)             │
│  • Validated efficient architecture choices                     │
│  • Disclosed the training data composition                      │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

1.2 Version Comparison

| Feature | LLaMA 1 | LLaMA 2 | LLaMA 3 | LLaMA 3.1 | LLaMA 3.2 |
|---------|---------|---------|---------|-----------|-----------|
| Release | 2023.02 | 2023.07 | 2024.04 | 2024.07 | 2024.09 |
| Sizes | 7/13/33/65B | 7/13/70B | 8/70B | 8/70/405B | 1/3/11/90B |
| Training tokens | 1.4T | 2T | 15T+ | 15T+ | 15T+ |
| Context | 2K | 4K | 8K | 128K | 128K |
| License | research-only | commercial (restricted) | commercial (relaxed) | commercial (relaxed) | commercial (relaxed) |
| GQA | ❌ | ✅ (70B) | ✅ (all) | ✅ (all) | ✅ (all) |
| Highlights | baseline architecture | RLHF, safety | improved reasoning | native 128K, Tool Use | vision models added |

Key 2024 updates:

  • LLaMA 3.1: native 128K context, 405B flagship model, Tool Use
  • LLaMA 3.2: lightweight models (1B/3B) and vision models (11B/90B)


2. LLaMA Architecture

2.1 Core Components

┌────────────────────────────────────────────────────────────────────┐
│                         LLaMA Architecture                         │
├────────────────────────────────────────────────────────────────────┤
│                                                                    │
│  Input Tokens                                                      │
│       │                                                            │
│       ▼                                                            │
│  ┌─────────────────────────────────────────┐                       │
│  │         Token Embedding                 │                       │
│  └─────────────────────────────────────────┘                       │
│       │                                                            │
│       ▼                                                            │
│  ┌─────────────────────────────────────────┐                       │
│  │     RoPE (Rotary Position Embedding)    │  ← position info      │
│  └─────────────────────────────────────────┘                       │
│       │                                                            │
│  ┌────┴────────────────────────────────────┐                       │
│  │         Transformer Block × N           │                       │
│  │  ┌───────────────────────────────────┐  │                       │
│  │  │  RMSNorm (Pre-normalization)      │  │  ← replaces LayerNorm │
│  │  │            ↓                      │  │                       │
│  │  │  Grouped Query Attention (GQA)    │  │  ← KV cache savings   │
│  │  │            ↓                      │  │                       │
│  │  │  Residual Connection              │  │                       │
│  │  │            ↓                      │  │                       │
│  │  │  RMSNorm                          │  │                       │
│  │  │            ↓                      │  │                       │
│  │  │  SwiGLU FFN                       │  │  ← replaces GELU      │
│  │  │            ↓                      │  │                       │
│  │  │  Residual Connection              │  │                       │
│  │  └───────────────────────────────────┘  │                       │
│  └─────────────────────────────────────────┘                       │
│       │                                                            │
│       ▼                                                            │
│  ┌─────────────────────────────────────────┐                       │
│  │         RMSNorm → Linear → Vocab        │                       │
│  └─────────────────────────────────────────┘                       │
│       │                                                            │
│       ▼                                                            │
│  Output Logits                                                     │
│                                                                    │
└────────────────────────────────────────────────────────────────────┘

2.2 Hyperparameters

"""
LLaMA ๋ชจ๋ธ ์‚ฌ์–‘ ๋น„๊ต
"""

LLAMA_CONFIGS = {
    # LLaMA 1 (2023.02): 2K context, MHA
    "llama-7b": {
        "dim": 4096,
        "n_layers": 32,
        "n_heads": 32,
        "n_kv_heads": 32,  # MHA (no GQA)
        "vocab_size": 32000,
        "ffn_dim": 11008,  # ≈ 2.7 × dim (2/3 of 4×dim, rounded)
        "context_length": 2048,
    },
    "llama-13b": {
        "dim": 5120,
        "n_layers": 40,
        "n_heads": 40,
        "n_kv_heads": 40,
        "vocab_size": 32000,
        "ffn_dim": 13824,
        "context_length": 2048,
    },
    # LLaMA 2 (2023.07): 4K context, GQA on the 70B model
    "llama-70b": {
        "dim": 8192,
        "n_layers": 80,
        "n_heads": 64,
        "n_kv_heads": 8,  # GQA! 8 KV heads
        "vocab_size": 32000,
        "ffn_dim": 28672,
        "context_length": 4096,
    },
    # LLaMA 3 (2024.04): 8K context, GQA everywhere, larger vocab
    "llama3-8b": {
        "dim": 4096,
        "n_layers": 32,
        "n_heads": 32,
        "n_kv_heads": 8,  # GQA
        "vocab_size": 128256,  # expanded vocabulary
        "ffn_dim": 14336,
        "context_length": 8192,
    },
    "llama3-70b": {
        "dim": 8192,
        "n_layers": 80,
        "n_heads": 64,
        "n_kv_heads": 8,  # GQA
        "vocab_size": 128256,
        "ffn_dim": 28672,
        "context_length": 8192,
    },
    # LLaMA 3.1 (2024.07)
    "llama3.1-8b": {
        "dim": 4096,
        "n_layers": 32,
        "n_heads": 32,
        "n_kv_heads": 8,
        "vocab_size": 128256,
        "ffn_dim": 14336,
        "context_length": 131072,  # native 128K
    },
    "llama3.1-405b": {
        "dim": 16384,
        "n_layers": 126,
        "n_heads": 128,
        "n_kv_heads": 8,
        "vocab_size": 128256,
        "ffn_dim": 53248,
        "context_length": 131072,  # native 128K
    },
    # LLaMA 3.2 (2024.09): lightweight text models
    "llama3.2-1b": {
        "dim": 2048,
        "n_layers": 16,
        "n_heads": 32,
        "n_kv_heads": 8,
        "vocab_size": 128256,
        "ffn_dim": 8192,
        "context_length": 131072,
    },
    "llama3.2-3b": {
        "dim": 3072,
        "n_layers": 28,
        "n_heads": 24,
        "n_kv_heads": 8,
        "vocab_size": 128256,
        "ffn_dim": 8192,
        "context_length": 131072,
    },
}
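
From these configs alone we can estimate each model's approximate parameter count. Below is a minimal sketch assuming no biases and untied input/output embeddings (estimate_params is a name made up for this example; some variants, such as LLaMA 3.2 1B, tie their embeddings, so treat the numbers as approximations).

def estimate_params(cfg: dict) -> int:
    """Rough parameter count from a config (no biases, untied embeddings)."""
    d = cfg["dim"]
    head_dim = d // cfg["n_heads"]
    kv_dim = cfg["n_kv_heads"] * head_dim
    attn = 2 * d * d + 2 * d * kv_dim   # Wq, Wo: d×d; Wk, Wv: d×kv_dim (GQA)
    ffn = 3 * d * cfg["ffn_dim"]        # SwiGLU: w1, w2, w3
    norms = 2 * d                       # two RMSNorms per layer
    embed = 2 * cfg["vocab_size"] * d   # input embedding + output head
    return cfg["n_layers"] * (attn + ffn + norms) + embed + d  # +d: final norm

for name in ["llama-7b", "llama3-8b", "llama3.1-405b"]:
    print(f"{name}: ~{estimate_params(LLAMA_CONFIGS[name]) / 1e9:.2f}B")
# llama-7b: ~6.74B / llama3-8b: ~8.03B / llama3.1-405b: ~405.85B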

3. RoPE (Rotary Position Embedding)

3.1 Concept

RoPE encodes position information by applying rotation matrices to the query and key vectors.

┌─────────────────────────────────────────────────────────────────┐
│                  Position Encoding Comparison                   │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. Sinusoidal (original Transformer)                           │
│     PE(pos, 2i) = sin(pos / 10000^(2i/d))                       │
│     PE(pos, 2i+1) = cos(pos / 10000^(2i/d))                     │
│     → added to the input (additive)                             │
│     → weak at capturing relative position                       │
│                                                                 │
│  2. Learned (BERT, GPT)                                         │
│     PE = Embedding(position)                                    │
│     → learned vectors                                           │
│     → hard to generalize beyond the training length             │
│                                                                 │
│  3. RoPE (LLaMA)                                                │
│     R(θ) = rotation matrix, θ = f(position)                     │
│     q' = R(θ_m) × q,  k' = R(θ_n) × k                           │
│     q' · k' depends only on θ_m - θ_n (relative position)       │
│     → encodes relative position naturally                       │
│     → length extrapolation possible (with modifications)        │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

3.2 Mathematical Details

import torch
import math

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """
    Precompute the complex-valued frequencies used by RoPE.

    Args:
        dim: embedding dimension (head_dim)
        seq_len: maximum sequence length
        theta: base frequency (10000)

    Returns:
        freqs_cis: (seq_len, dim//2) complex tensor
    """
    # frequencies: θ_i = 1 / (theta^(2i/d))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    # per-position angles: m × θ_i
    t = torch.arange(seq_len)
    freqs = torch.outer(t, freqs)  # (seq_len, dim//2)
    # complex form: e^(iθ) = cos(θ) + i·sin(θ)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def apply_rotary_emb(xq, xk, freqs_cis):
    """
    Apply RoPE to the queries and keys.

    Args:
        xq: queries (batch, seq_len, n_heads, head_dim)
        xk: keys (batch, seq_len, n_kv_heads, head_dim)
        freqs_cis: precomputed complex frequencies

    Returns:
        the rotated queries and keys
    """
    # view adjacent pairs of reals as complex numbers
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # apply the rotation (complex multiplication)
    freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2)  # (1, seq, 1, dim//2)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

# example
batch, seq_len, n_heads, head_dim = 2, 128, 32, 128
xq = torch.randn(batch, seq_len, n_heads, head_dim)
xk = torch.randn(batch, seq_len, n_heads, head_dim)
freqs_cis = precompute_freqs_cis(head_dim, seq_len)

xq_rope, xk_rope = apply_rotary_emb(xq, xk, freqs_cis)
print(f"Output shape: {xq_rope.shape}")  # (2, 128, 32, 128)

3.3 Advantages of RoPE

"""
RoPE์˜ ์žฅ์ :

1. ์ƒ๋Œ€ ์œ„์น˜ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ์ธ์ฝ”๋”ฉ
   - q_m ยท k_n โˆ cos(ฮธ_m - ฮธ_n)
   - ์ ˆ๋Œ€ ์œ„์น˜๊ฐ€ ์•„๋‹Œ ์ƒ๋Œ€ ๊ฑฐ๋ฆฌ ์˜์กด

2. ์™ธ์‚ฝ ๊ฐ€๋Šฅ์„ฑ
   - ํ•™์Šต ์‹œ ๋ณธ ๊ธธ์ด ์ด์ƒ์œผ๋กœ ํ™•์žฅ ๊ฐ€๋Šฅ
   - (๋‹จ, ์„ฑ๋Šฅ ์ €ํ•˜ ์žˆ์Œ โ†’ NTK, YaRN์œผ๋กœ ๊ฐœ์„ )

3. ํšจ์œจ์„ฑ
   - ์ถ”๊ฐ€ ํŒŒ๋ผ๋ฏธํ„ฐ ์—†์Œ
   - Element-wise ์—ฐ์‚ฐ์œผ๋กœ ๋น ๋ฆ„

4. ์„ ํ˜• Self-attention๊ณผ ํ˜ธํ™˜
   - ์ผ๋ถ€ ํšจ์œจ์  attention ๋ฐฉ์‹๊ณผ ๊ฒฐํ•ฉ ๊ฐ€๋Šฅ
"""

4. RMSNorm

4.1 Concept

RMSNorm is a simplified LayerNorm that removes the mean-centering step.

┌─────────────────────────────────────────────────────────────────┐
│                      LayerNorm vs RMSNorm                       │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  LayerNorm:                                                     │
│  ──────────────────────────────                                 │
│  μ = mean(x)                                                    │
│  σ = std(x)                                                     │
│  y = γ × (x - μ) / σ + β                                        │
│                                                                 │
│  • subtract the mean, divide by the std                         │
│  • learnable scale (γ) and shift (β)                            │
│                                                                 │
│  RMSNorm:                                                       │
│  ──────────────────────────────                                 │
│  RMS(x) = sqrt(mean(x^2))                                       │
│  y = γ × x / RMS(x)                                             │
│                                                                 │
│  • no mean subtraction → no re-centering                        │
│  • scales by the RMS only                                       │
│  • no shift (β)                                                 │
│  • less compute, similar quality                                │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

4.2 Implementation

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization

    Paper: https://arxiv.org/abs/1910.07467
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # scale parameter γ

    def _norm(self, x):
        # RMS = sqrt(mean(x^2))
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # output = γ × (x / RMS(x))
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

# compare against LayerNorm
x = torch.randn(2, 10, 512)

layer_norm = nn.LayerNorm(512)
rms_norm = RMSNorm(512)

# timing comparison (RMSNorm is slightly faster)
import time

start = time.time()
for _ in range(1000):
    _ = layer_norm(x)
print(f"LayerNorm: {time.time() - start:.4f}s")

start = time.time()
for _ in range(1000):
    _ = rms_norm(x)
print(f"RMSNorm: {time.time() - start:.4f}s")

5. SwiGLU

5.1 Concept

SwiGLU is a GLU (Gated Linear Unit) variant that uses the Swish activation function.

┌─────────────────────────────────────────────────────────────────┐
│                    FFN Activation Comparison                    │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. ReLU FFN (original Transformer):                            │
│     FFN(x) = max(0, xW₁ + b₁)W₂ + b₂                            │
│     • simple, but loses information in the negative range       │
│                                                                 │
│  2. GELU FFN (BERT, GPT):                                       │
│     FFN(x) = GELU(xW₁)W₂                                        │
│     GELU(x) = x × Φ(x)  (Φ = CDF of normal)                     │
│     • smooth activation, better performance                     │
│                                                                 │
│  3. SwiGLU FFN (LLaMA):                                         │
│     FFN(x) = (Swish(xW₁) ⊙ xV)W₂                                │
│     Swish(x) = x × σ(x)  (σ = sigmoid)                          │
│     ⊙ = element-wise multiplication                             │
│                                                                 │
│     • gating mechanism controls the information flow            │
│     • more parameters, better performance                       │
│     • hidden dim = 2/3 × 4d (keeps parameter count similar)     │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

5.2 Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """
    SwiGLU: Swish-Gated Linear Unit

    FFN(x) = (Swish(xW₁) ⊙ xV) W₂

    Paper: https://arxiv.org/abs/2002.05202
    """
    def __init__(self, dim: int, hidden_dim: int = None, multiple_of: int = 256):
        super().__init__()

        # hidden_dim: 2/3 × 4d, rounded up to a multiple of 256
        if hidden_dim is None:
            hidden_dim = int(2 * (4 * dim) / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection

    def forward(self, x):
        # SwiGLU: Swish(xW₁) ⊙ (xW₃) → W₂
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

# compare with a standard FFN
class StandardFFN(nn.Module):
    def __init__(self, dim: int, hidden_dim: int = None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.w2(F.gelu(self.w1(x)))

# parameter count comparison
dim = 4096
swiglu = SwiGLU(dim)  # 3 linears: dim→hidden, dim→hidden, hidden→dim
standard = StandardFFN(dim)  # 2 linears: dim→4·dim, 4·dim→dim

print(f"SwiGLU params: {sum(p.numel() for p in swiglu.parameters()):,}")
print(f"Standard FFN params: {sum(p.numel() for p in standard.parameters()):,}")
# SwiGLU has slightly more; tuning hidden_dim keeps the counts comparable
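
A quick shape check, plus a comparison of the Swish (SiLU) and GELU activations (the inputs below are arbitrary example values):

# forward shape check
x = torch.randn(2, 16, dim)
print(swiglu(x).shape)  # torch.Size([2, 16, 4096])

# Swish (SiLU) vs GELU: the curves differ slightly in the negative range
t = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
print(F.silu(t))  # x × σ(x)
print(F.gelu(t))  # x × Φ(x)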

6. Grouped Query Attention (GQA)

6.1 Concept

GQA is a middle ground between Multi-Head Attention and Multi-Query Attention.

┌─────────────────────────────────────────────────────────────────┐
│                  Attention Variant Comparison                   │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. Multi-Head Attention (MHA):                                 │
│     Q heads: 32  │  K heads: 32  │  V heads: 32                 │
│     • each Q head has its own KV head                           │
│     • memory: high (32 × KV cache)                              │
│                                                                 │
│  2. Multi-Query Attention (MQA):                                │
│     Q heads: 32  │  K heads: 1   │  V heads: 1                  │
│     • all Q heads share a single KV pair                        │
│     • memory: minimal (1 × KV cache)                            │
│     • quality: slightly below MHA                               │
│                                                                 │
│  3. Grouped Query Attention (GQA):                              │
│     Q heads: 32  │  K heads: 8   │  V heads: 8                  │
│     • Q heads form groups that share KV heads                   │
│     • e.g., 4 Q heads share 1 KV head                           │
│     • memory: intermediate (8 × KV cache)                       │
│     • quality: nearly matches MHA                               │
│                                                                 │
│  ┌──────────────┬───────────────┬───────────────┐               │
│  │ MHA          │ MQA           │ GQA           │               │
│  │ Q Q Q Q Q Q  │ Q Q Q Q Q Q   │ Q Q│Q Q│Q Q   │               │
│  │ ↓ ↓ ↓ ↓ ↓ ↓  │ ↓ ↓ ↓ ↓ ↓ ↓   │ ↓ ↓│↓ ↓│↓ ↓   │               │
│  │ K K K K K K  │     K         │  K │ K │ K    │               │
│  │ V V V V V V  │     V         │  V │ V │ V    │               │
│  └──────────────┴───────────────┴───────────────┘               │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

6.2 Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA)

    Paper: https://arxiv.org/abs/2305.13245
    """
    def __init__(
        self,
        dim: int,
        n_heads: int = 32,
        n_kv_heads: int = 8,  # number of KV heads (< n_heads)
        head_dim: int = None,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim or dim // n_heads

        # n_heads must divide evenly into KV groups
        assert n_heads % n_kv_heads == 0
        self.n_rep = n_heads // n_kv_heads  # Q heads served by each KV head

        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)

    def forward(self, x, freqs_cis=None, mask=None, kv_cache=None):
        batch, seq_len, _ = x.shape

        # project to Q, K, V
        xq = self.wq(x).view(batch, seq_len, self.n_heads, self.head_dim)
        xk = self.wk(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
        xv = self.wv(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)

        # apply RoPE (if provided)
        if freqs_cis is not None:
            xq, xk = apply_rotary_emb(xq, xk, freqs_cis)

        # KV cache handling (at inference time)
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            xk = torch.cat([cache_k, xk], dim=1)
            xv = torch.cat([cache_v, xv], dim=1)

        # store the cache for the next step BEFORE repeating, so it keeps
        # the compact n_kv_heads layout (this is where the memory win comes from)
        new_cache = (xk, xv)

        # expand KV heads: n_kv_heads → n_heads
        # (batch, seq, n_kv_heads, head_dim) → (batch, seq, n_heads, head_dim)
        xk = self._repeat_kv(xk)
        xv = self._repeat_kv(xv)

        # attention
        xq = xq.transpose(1, 2)  # (batch, n_heads, seq, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores + mask

        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, xv)

        # merge the heads
        output = output.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(output), new_cache

    def _repeat_kv(self, x):
        """Repeat KV heads to match the number of Q heads."""
        if self.n_rep == 1:
            return x
        batch, seq_len, n_kv_heads, head_dim = x.shape
        x = x[:, :, :, None, :].expand(batch, seq_len, n_kv_heads, self.n_rep, head_dim)
        return x.reshape(batch, seq_len, n_kv_heads * self.n_rep, head_dim)
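
# quick smoke test (a sketch: no RoPE or mask, checking shapes only)
gqa = GroupedQueryAttention(dim=4096, n_heads=32, n_kv_heads=8)
x = torch.randn(2, 16, 4096)
out, (cache_k, cache_v) = gqa(x)
print(out.shape)      # torch.Size([2, 16, 4096])
print(cache_k.shape)  # torch.Size([2, 16, 8, 128]) ← cache holds only n_kv_heads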

# KV cache memory comparison
def compare_kv_cache_memory(seq_len, batch_size=1, dtype_bytes=2):
    """Compare KV cache sizes (FP16)."""
    configs = {
        "LLaMA-70B (MHA)": {"n_layers": 80, "n_kv_heads": 64, "head_dim": 128},
        "LLaMA-70B (GQA)": {"n_layers": 80, "n_kv_heads": 8, "head_dim": 128},
    }

    for name, cfg in configs.items():
        kv_mem = (2 *  # K and V
                  batch_size *
                  cfg["n_layers"] *
                  seq_len *
                  cfg["n_kv_heads"] *
                  cfg["head_dim"] *
                  dtype_bytes)
        print(f"{name}: {kv_mem / 1e9:.2f} GB for {seq_len} tokens")

compare_kv_cache_memory(4096)
# LLaMA-70B (MHA): 10.74 GB for 4096 tokens
# LLaMA-70B (GQA): 1.34 GB for 4096 tokens  ← 8× savings!
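
The Transformer block from the diagram in section 2.1 falls out of composing the pieces built above. Below is a minimal sketch that assumes the RMSNorm, GroupedQueryAttention, and SwiGLU classes from the previous sections are in scope (LlamaBlock is a name made up for this example):

import torch
import torch.nn as nn

class LlamaBlock(nn.Module):
    """LLaMA-style pre-norm block (see the diagram in section 2.1)."""
    def __init__(self, dim: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        self.attn_norm = RMSNorm(dim)                                # section 4
        self.attn = GroupedQueryAttention(dim, n_heads, n_kv_heads)  # section 6
        self.ffn_norm = RMSNorm(dim)
        self.ffn = SwiGLU(dim)                                       # section 5

    def forward(self, x, freqs_cis=None, mask=None):
        # RMSNorm → GQA → residual
        attn_out, _ = self.attn(self.attn_norm(x), freqs_cis, mask)
        h = x + attn_out
        # RMSNorm → SwiGLU FFN → residual
        return h + self.ffn(self.ffn_norm(h))

block = LlamaBlock(dim=512, n_heads=8, n_kv_heads=2)
print(block(torch.randn(2, 16, 512)).shape)  # torch.Size([2, 16, 512])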

7. Hands-on with LLaMA

7.1 Using LLaMA with Hugging Face

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# load LLaMA 2 7B
model_name = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

# text generation helper
def generate_text(prompt, max_new_tokens=100, temperature=0.7):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# usage example
prompt = "Explain the concept of machine learning in simple terms:"
response = generate_text(prompt)
print(response)

7.2 Efficient Inference with Quantization

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",          # NormalFloat4
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,     # double quantization
)

# load the quantized model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto"
)

# check the memory footprint
print(f"Model memory: {model.get_memory_footprint() / 1e9:.2f} GB")
# roughly 4 GB (about 75% less than FP16)

7.3 Using LLaMA 3

# LLaMA 3 8B (128K-token vocabulary); the Instruct variant matches the
# chat-format prompt used below
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # bfloat16 is recommended for LLaMA 3
    device_map="auto"
)

# LLaMA 3 highlights:
# - 128K-token vocabulary (more efficient tokenization)
# - 8K base context (extended to 128K in 3.1)
# - improved reasoning

prompt = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

8. LLaMA 3.1/3.2 in Detail

8.1 LLaMA 3.1 (July 2024)

┌─────────────────────────────────────────────────────────────────┐
│                     LLaMA 3.1 Key Features                      │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. Native 128K context                                         │
│     • 128K tokens supported from pretraining onward             │
│     • long-context handling without RoPE scaling                │
│                                                                 │
│  2. 405B flagship model                                         │
│     • GPT-4-class performance                                   │
│     • 126 layers, 16K embedding dim                             │
│                                                                 │
│  3. Tool Use                                                    │
│     • function calling                                          │
│     • code interpreter                                          │
│     • search tool integration                                   │
│                                                                 │
│  4. Stronger multilingual support                               │
│     • English, German, French, Italian                          │
│     • Portuguese, Hindi, Spanish, Thai                          │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

# LLaMA 3.1 Tool Use example
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "meta-llama/Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# tool schema (the LLaMA 3.1 chat template supports tool definitions)
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name"}
                },
                "required": ["location"]
            }
        }
    }
]

messages = [
    {"role": "system", "content": "You are a helpful assistant with access to tools."},
    {"role": "user", "content": "What's the weather in Seoul?"}
]

# generate the tool call
inputs = tokenizer.apply_chat_template(
    messages,
    tools=tools,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

outputs = model.generate(inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0]))

8.2 LLaMA 3.2 (September 2024)

┌─────────────────────────────────────────────────────────────────┐
│                     LLaMA 3.2 Model Lineup                      │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Lightweight text models (on-device optimized):                 │
│  ┌─────────────────────────────────────────────┐                │
│  │  LLaMA 3.2 1B: mobile / edge devices        │                │
│  │  LLaMA 3.2 3B: lightweight applications     │                │
│  └─────────────────────────────────────────────┘                │
│                                                                 │
│  Vision models (multimodal):                                    │
│  ┌─────────────────────────────────────────────┐                │
│  │  LLaMA 3.2 11B-Vision: image understanding  │                │
│  │  LLaMA 3.2 90B-Vision: high-end vision tasks│                │
│  └─────────────────────────────────────────────┘                │
│                                                                 │
│  Highlights:                                                    │
│  • 1B/3B: 128K context, on-device inference                     │
│  • 11B/90B: vision encoder, image + text input                  │
│  • optimized for Qualcomm and MediaTek hardware                 │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

# LLaMA 3.2 Vision example
from transformers import MllamaForConditionalGeneration, AutoProcessor
from PIL import Image
import requests
import torch

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)

# load an image
url = "https://example.com/image.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# vision chat
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What do you see in this image?"}
        ]
    }
]

input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(image, input_text, return_tensors="pt").to(model.device)

output = model.generate(**inputs, max_new_tokens=256)
print(processor.decode(output[0]))

Summary

Core LLaMA Techniques

| Technique | Effect |
|-----------|--------|
| RoPE | relative position encoding, extendable context length |
| RMSNorm | faster than LayerNorm with comparable quality |
| SwiGLU | better performance than GELU |
| GQA | 8× smaller KV cache (70B configuration) |

Practical Recommendations

  1. 7B/8B: single GPU (16GB+), good for quick experiments
  2. 13B: 24GB GPU, a balanced choice
  3. 70B: multiple GPUs, for maximum quality
  4. Quantization: 4-bit cuts memory by about 75%

๋‹ค์Œ ๋‹จ๊ณ„


References

Key Papers

  • Touvron et al. (2023). "LLaMA: Open and Efficient Foundation Language Models"
  • Touvron et al. (2023). "Llama 2: Open Foundation and Fine-Tuned Chat Models"
  • Su et al. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding"
  • Ainslie et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
