LLaMA Family

Learning Objectives

  • Understand the architectural evolution of LLaMA 1/2/3
  • Master core technologies: RoPE, RMSNorm, SwiGLU
  • Grasp the Grouped Query Attention (GQA) mechanism
  • Learn practical LLaMA usage methods

1. LLaMA Overview

1.1 Significance of LLaMA

LLaMA (Large Language Model Meta AI) is a family of openly released LLMs from Meta, first published in 2023, which led the democratization of Foundation Model research.

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    Historical Significance of LLaMA              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚  Before LLaMA (2022):                                           β”‚
β”‚  β€’ Best-performing models only available via API (GPT-3.5, PaLM)β”‚
β”‚  β€’ Academic research models lacked performance                  β”‚
β”‚  β€’ Limited LLM access for open-source community                 β”‚
β”‚                                                                 β”‚
β”‚  After LLaMA (2023):                                            β”‚
β”‚  β€’ Researchers can directly experiment with state-of-the-art    β”‚
β”‚  β€’ Explosive growth of derivative models (Alpaca, Vicuna, etc.) β”‚
β”‚  β€’ Rapid acceleration of LLM research                           β”‚
β”‚                                                                 β”‚
β”‚  Key Contributions:                                              β”‚
β”‚  β€’ Applied Chinchilla rules (D=20N or more)                     β”‚
β”‚  β€’ Validated efficient architecture choices                      β”‚
β”‚  β€’ Published training data composition                           β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

1.2 Version Comparison

| Feature  | LLaMA 1           | LLaMA 2                  | LLaMA 3              | LLaMA 3.1              | LLaMA 3.2            |
|----------|-------------------|--------------------------|----------------------|------------------------|----------------------|
| Release  | 2023.02           | 2023.07                  | 2024.04              | 2024.07                | 2024.09              |
| Sizes    | 7/13/33/65B       | 7/13/70B                 | 8/70B                | 8/70/405B              | 1/3/11/90B           |
| Tokens   | 1.4T              | 2T                       | 15T+                 | 15T+                   | 15T+                 |
| Context  | 2K                | 4K                       | 8K                   | 128K                   | 128K                 |
| License  | Research          | Commercial (conditional) | Commercial (relaxed) | Commercial (relaxed)   | Commercial (relaxed) |
| GQA      | ❌                | ✅ (70B)                 | ✅ (all)             | ✅ (all)               | ✅ (all)             |
| Features | Base architecture | RLHF, Safety             | Improved reasoning   | 128K native, Tool Use  | Vision models added  |

LLaMA 3.1/3.2 Major Updates (2024):

  • LLaMA 3.1: 128K native context, 405B flagship model, Tool Use capability
  • LLaMA 3.2: Lightweight models (1B/3B) and Vision models (11B/90B) added


2. LLaMA Architecture

2.1 Core Components

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    LLaMA Architecture                            β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚  Input Tokens                                                   β”‚
β”‚       β”‚                                                         β”‚
β”‚       β–Ό                                                         β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                    β”‚
β”‚  β”‚         Token Embedding                 β”‚                    β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                    β”‚
β”‚       β”‚                                                         β”‚
β”‚       β–Ό                                                         β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                    β”‚
β”‚  β”‚     RoPE (Rotary Position Embedding)    β”‚  ← Position info   β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                    β”‚
β”‚       β”‚                                                         β”‚
β”‚  β”Œβ”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                    β”‚
β”‚  β”‚         Transformer Block Γ— N           β”‚                    β”‚
β”‚  β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”  β”‚                    β”‚
β”‚  β”‚  β”‚  RMSNorm (Pre-normalization)      β”‚  β”‚  ← Replaces LN     β”‚
β”‚  β”‚  β”‚            ↓                      β”‚  β”‚                    β”‚
β”‚  β”‚  β”‚  Grouped Query Attention (GQA)    β”‚  β”‚  ← KV Cache eff.   β”‚
β”‚  β”‚  β”‚            ↓                      β”‚  β”‚                    β”‚
β”‚  β”‚  β”‚  Residual Connection              β”‚  β”‚                    β”‚
β”‚  β”‚  β”‚            ↓                      β”‚  β”‚                    β”‚
β”‚  β”‚  β”‚  RMSNorm                          β”‚  β”‚                    β”‚
β”‚  β”‚  β”‚            ↓                      β”‚  β”‚                    β”‚
β”‚  β”‚  β”‚  SwiGLU FFN                       β”‚  β”‚  ← Replaces GELU   β”‚
β”‚  β”‚  β”‚            ↓                      β”‚  β”‚                    β”‚
β”‚  β”‚  β”‚  Residual Connection              β”‚  β”‚                    β”‚
β”‚  β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜  β”‚                    β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                    β”‚
β”‚       β”‚                                                         β”‚
β”‚       β–Ό                                                         β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                    β”‚
β”‚  β”‚         RMSNorm β†’ Linear β†’ Vocab        β”‚                    β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                    β”‚
β”‚       β”‚                                                         β”‚
β”‚       β–Ό                                                         β”‚
β”‚  Output Logits                                                  β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

2.2 Hyperparameters

"""
LLaMA Model Specification Comparison
"""

LLAMA_CONFIGS = {
    # LLaMA 1 (2023.02)
    "llama-7b": {
        "dim": 4096,
        "n_layers": 32,
        "n_heads": 32,
        "n_kv_heads": 32,  # MHA (no GQA)
        "vocab_size": 32000,
        "ffn_dim": 11008,  # approx 2.67 Γ— dim
        "context_length": 2048,
    },
    "llama-13b": {
        "dim": 5120,
        "n_layers": 40,
        "n_heads": 40,
        "n_kv_heads": 40,
        "vocab_size": 32000,
        "ffn_dim": 13824,
        "context_length": 2048,
    },
    "llama-70b": {
        "dim": 8192,
        "n_layers": 80,
        "n_heads": 64,
        "n_kv_heads": 8,  # GQA! 8 KV heads
        "vocab_size": 32000,
        "ffn_dim": 28672,
        "context_length": 4096,
    },
    "llama3-8b": {
        "dim": 4096,
        "n_layers": 32,
        "n_heads": 32,
        "n_kv_heads": 8,  # GQA
        "vocab_size": 128256,  # Extended vocab
        "ffn_dim": 14336,
        "context_length": 8192,
    },
    "llama3-70b": {
        "dim": 8192,
        "n_layers": 80,
        "n_heads": 64,
        "n_kv_heads": 8,  # GQA
        "vocab_size": 128256,
        "ffn_dim": 28672,
        "context_length": 8192,
    },
    # LLaMA 3.1 (2024.07)
    "llama3.1-8b": {
        "dim": 4096,
        "n_layers": 32,
        "n_heads": 32,
        "n_kv_heads": 8,
        "vocab_size": 128256,
        "ffn_dim": 14336,
        "context_length": 131072,  # 128K native
    },
    "llama3.1-405b": {
        "dim": 16384,
        "n_layers": 126,
        "n_heads": 128,
        "n_kv_heads": 8,
        "vocab_size": 128256,
        "ffn_dim": 53248,
        "context_length": 131072,  # 128K native
    },
    # LLaMA 3.2 (2024.09) - Lightweight text models
    "llama3.2-1b": {
        "dim": 2048,
        "n_layers": 16,
        "n_heads": 32,
        "n_kv_heads": 8,
        "vocab_size": 128256,
        "ffn_dim": 8192,
        "context_length": 131072,
    },
    "llama3.2-3b": {
        "dim": 3072,
        "n_layers": 28,
        "n_heads": 24,
        "n_kv_heads": 8,
        "vocab_size": 128256,
        "ffn_dim": 8192,
        "context_length": 131072,
    },
}
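These configs can be sanity-checked by deriving an approximate parameter count from the dimensions alone. A minimal sketch, assuming LLaMA's layout (bias-free linears, SwiGLU with three FFN matrices, two RMSNorms per block, untied input/output embeddings):

def estimate_params(cfg):
    """Rough parameter count derived from a LLAMA_CONFIGS entry."""
    d, n_layers = cfg["dim"], cfg["n_layers"]
    head_dim = d // cfg["n_heads"]
    q_dim = cfg["n_heads"] * head_dim
    kv_dim = cfg["n_kv_heads"] * head_dim
    attn = d * q_dim + 2 * d * kv_dim + q_dim * d       # wq, wk, wv, wo
    ffn = 3 * d * cfg["ffn_dim"]                        # w1 (gate), w3 (up), w2 (down)
    norms = 2 * d                                       # two RMSNorms per block
    embed = 2 * cfg["vocab_size"] * d                   # input + output embeddings
    return n_layers * (attn + ffn + norms) + embed + d  # + final RMSNorm

for name in ["llama-7b", "llama3-8b", "llama-70b"]:
    print(f"{name}: ~{estimate_params(LLAMA_CONFIGS[name]) / 1e9:.2f}B parameters")
# llama-7b: ~6.74B, llama3-8b: ~8.03B, llama-70b: ~68.98B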

3. RoPE (Rotary Position Embedding)

3.1 Concept

RoPE is a method that encodes positional information using rotation matrices.

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    Position Encoding Comparison                  β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚  1. Sinusoidal (Original Transformer)                           β”‚
β”‚     PE(pos, 2i) = sin(pos / 10000^(2i/d))                       β”‚
β”‚     PE(pos, 2i+1) = cos(pos / 10000^(2i/d))                     β”‚
β”‚     β†’ Added to input (additive)                                 β”‚
β”‚     β†’ Weak relative position information                        β”‚
β”‚                                                                 β”‚
β”‚  2. Learned (BERT, GPT)                                         β”‚
β”‚     PE = Embedding(position)                                    β”‚
β”‚     β†’ Learned vectors                                           β”‚
β”‚     β†’ Difficult to generalize beyond trained lengths            β”‚
β”‚                                                                 β”‚
β”‚  3. RoPE (LLaMA)                                                β”‚
β”‚     R(ΞΈ) = rotation matrix, ΞΈ = f(position)                     β”‚
β”‚     q' = R(ΞΈ_m) Γ— q, k' = R(ΞΈ_n) Γ— k                           β”‚
β”‚     q' Β· k' = q Β· k Γ— cos(ΞΈ_m - ΞΈ_n)                           β”‚
β”‚     β†’ Naturally encodes relative positions                      β”‚
β”‚     β†’ Length extrapolation possible (with modifications)        β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

3.2 Mathematical Understanding

import torch
import math

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """
    Precompute complex frequencies for RoPE

    Args:
        dim: Embedding dimension (head_dim)
        seq_len: Maximum sequence length
        theta: Base frequency (10000)

    Returns:
        freqs_cis: (seq_len, dim//2) complex tensor
    """
    # Frequency calculation: ΞΈ_i = 1 / (theta^(2i/d))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    # Angle per position: m * ΞΈ_i
    t = torch.arange(seq_len)
    freqs = torch.outer(t, freqs)  # (seq_len, dim//2)
    # Complex form: e^(i*ΞΈ) = cos(ΞΈ) + i*sin(ΞΈ)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def apply_rotary_emb(xq, xk, freqs_cis):
    """
    Apply RoPE to Query and Key

    Args:
        xq: Query (batch, seq_len, n_heads, head_dim)
        xk: Key (batch, seq_len, n_kv_heads, head_dim)
        freqs_cis: Precomputed complex frequencies

    Returns:
        Rotated Query and Key
    """
    # Convert real to complex (pair adjacent elements)
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Apply rotation (complex multiplication)
    freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2)  # (1, seq, 1, dim//2)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

# Example
batch, seq_len, n_heads, head_dim = 2, 128, 32, 128
xq = torch.randn(batch, seq_len, n_heads, head_dim)
xk = torch.randn(batch, seq_len, n_heads, head_dim)
freqs_cis = precompute_freqs_cis(head_dim, seq_len)

xq_rope, xk_rope = apply_rotary_emb(xq, xk, freqs_cis)
print(f"Output shape: {xq_rope.shape}")  # (2, 128, 32, 128)

3.3 Advantages of RoPE

"""
Advantages of RoPE:

1. Natural relative position encoding
   - ⟨R(m)q, R(n)k⟩ = ⟨R(m-n)q, k⟩
   - Depends on relative distance, not absolute position

2. Extrapolation capability
   - Can extend beyond trained lengths
   - (Performance degrades β†’ improved with NTK, YaRN)

3. Efficiency
   - No additional parameters
   - Fast element-wise operations

4. Compatible with linear Self-attention
   - Can combine with some efficient attention methods
"""

4. RMSNorm

4.1 Concept

RMSNorm is a simplified version of LayerNorm that removes mean computation.

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    LayerNorm vs RMSNorm                          β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚  LayerNorm:                                                     β”‚
β”‚  ──────────────────────────────────                             β”‚
β”‚  ΞΌ = mean(x)                                                    β”‚
β”‚  Οƒ = std(x)                                                     β”‚
β”‚  y = Ξ³ Γ— (x - ΞΌ) / Οƒ + Ξ²                                        β”‚
β”‚                                                                 β”‚
β”‚  β€’ Subtract mean + divide by variance                           β”‚
β”‚  β€’ Learnable scale(Ξ³) and shift(Ξ²)                              β”‚
β”‚                                                                 β”‚
β”‚  RMSNorm:                                                       β”‚
β”‚  ──────────────────────────────────                             β”‚
β”‚  RMS(x) = sqrt(mean(x^2))                                       β”‚
β”‚  y = Ξ³ Γ— x / RMS(x)                                             β”‚
β”‚                                                                 β”‚
β”‚  β€’ No mean subtraction β†’ Re-centering removed                   β”‚
β”‚  β€’ Scale by RMS only                                            β”‚
β”‚  β€’ No shift(Ξ²)                                                  β”‚
β”‚  β€’ Reduced computation, similar performance                     β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

4.2 Implementation

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization

    Paper: https://arxiv.org/abs/1910.07467
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # scale parameter Ξ³

    def _norm(self, x):
        # RMS = sqrt(mean(x^2))
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # output = Ξ³ Γ— (x / RMS(x))
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

# Compare with LayerNorm
x = torch.randn(2, 10, 512)

layer_norm = nn.LayerNorm(512)
rms_norm = RMSNorm(512)

# Computation time comparison (RMSNorm is slightly faster)
import time

start = time.time()
for _ in range(1000):
    _ = layer_norm(x)
print(f"LayerNorm: {time.time() - start:.4f}s")

start = time.time()
for _ in range(1000):
    _ = rms_norm(x)
print(f"RMSNorm: {time.time() - start:.4f}s")

5. SwiGLU

5.1 Concept

SwiGLU is a variant of GLU (Gated Linear Unit) that uses the Swish activation function.

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    FFN Activation Function Comparison            β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚  1. ReLU FFN (Original Transformer):                            β”‚
β”‚     FFN(x) = max(0, xW₁ + b₁)Wβ‚‚ + bβ‚‚                            β”‚
β”‚     β€’ Simple but loses information in negative region           β”‚
β”‚                                                                 β”‚
β”‚  2. GELU FFN (BERT, GPT):                                       β”‚
β”‚     FFN(x) = GELU(xW₁)Wβ‚‚                                        β”‚
β”‚     GELU(x) = x Γ— Ξ¦(x)  (Ξ¦ = CDF of normal)                     β”‚
β”‚     β€’ Smooth activation, improved performance                   β”‚
β”‚                                                                 β”‚
β”‚  3. SwiGLU FFN (LLaMA):                                         β”‚
β”‚     FFN(x) = (Swish(xW₁) βŠ™ xV)Wβ‚‚                                β”‚
β”‚     Swish(x) = x Γ— Οƒ(x)  (Οƒ = sigmoid)                          β”‚
β”‚     βŠ™ = element-wise multiplication                             β”‚
β”‚                                                                 β”‚
β”‚     β€’ Gating mechanism controls information flow                β”‚
β”‚     β€’ More parameters, better performance                       β”‚
β”‚     β€’ 2/3 Γ— 4d hidden dim (maintains parameter count)           β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

5.2 Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """
    SwiGLU: Swish-Gated Linear Unit

    FFN(x) = (Swish(xW₁) βŠ™ xV) Wβ‚‚

    Paper: https://arxiv.org/abs/2002.05202
    """
    def __init__(self, dim: int, hidden_dim: int = None, multiple_of: int = 256):
        super().__init__()

        # hidden_dim calculation: 2/3 Γ— 4d, rounded to multiple of 256
        if hidden_dim is None:
            hidden_dim = int(2 * (4 * dim) / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection

    def forward(self, x):
        # SwiGLU: Swish(xW₁) βŠ™ (xW₃) β†’ Wβ‚‚
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

# Compare with standard FFN
class StandardFFN(nn.Module):
    def __init__(self, dim: int, hidden_dim: int = None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.w2(F.gelu(self.w1(x)))

# Parameter count comparison
dim = 4096
swiglu = SwiGLU(dim)  # 3 linear layers: dim→hidden, dim→hidden, hidden→dim
standard = StandardFFN(dim)  # 2 linear layers: dim→4*dim, 4*dim→dim

print(f"SwiGLU params: {sum(p.numel() for p in swiglu.parameters()):,}")
print(f"Standard FFN params: {sum(p.numel() for p in standard.parameters()):,}")
# SwiGLU has slightly more but adjusted with hidden_dim to be similar
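The 11008 value in the llama-7b config from Section 2.2 falls out of exactly this rounding rule:

d = 4096
hidden = int(2 * (4 * d) / 3)            # 2/3 × 4d = 10922
hidden = 256 * ((hidden + 255) // 256)   # round up to a multiple of 256
print(hidden)                            # 11008 = LLAMA_CONFIGS["llama-7b"]["ffn_dim"]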

6. Grouped Query Attention (GQA)

6.1 Concept

GQA is an intermediate form between Multi-Head Attention and Multi-Query Attention.

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    Attention Method Comparison                   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚  1. Multi-Head Attention (MHA):                                 β”‚
β”‚     Q heads: 32  β”‚  K heads: 32  β”‚  V heads: 32                 β”‚
β”‚     β€’ Each Q head uses independent KV head                      β”‚
β”‚     β€’ Memory: High (32 Γ— KV cache)                              β”‚
β”‚                                                                 β”‚
β”‚  2. Multi-Query Attention (MQA):                                β”‚
β”‚     Q heads: 32  β”‚  K heads: 1   β”‚  V heads: 1                  β”‚
β”‚     β€’ All Q heads share same KV                                 β”‚
β”‚     β€’ Memory: Minimum (1 Γ— KV cache)                            β”‚
β”‚     β€’ Quality: Slightly lower than MHA                          β”‚
β”‚                                                                 β”‚
β”‚  3. Grouped Query Attention (GQA):                              β”‚
β”‚     Q heads: 32  β”‚  K heads: 8   β”‚  V heads: 8                  β”‚
β”‚     β€’ Q heads divided into groups sharing KV                    β”‚
β”‚     β€’ Example: 4 Q heads share 1 KV head                        β”‚
β”‚     β€’ Memory: Medium (8 Γ— KV cache)                             β”‚
β”‚     β€’ Quality: Nearly identical to MHA                          β”‚
β”‚                                                                 β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”               β”‚
β”‚  β”‚ MHA          β”‚ MQA           β”‚ GQA           β”‚               β”‚
β”‚  β”‚ Q Q Q Q Q Q  β”‚ Q Q Q Q Q Q   β”‚ Q Qβ”‚Q Qβ”‚Q Q   β”‚               β”‚
β”‚  β”‚ ↓ ↓ ↓ ↓ ↓ ↓  β”‚ ↓ ↓ ↓ ↓ ↓ ↓   β”‚ ↓ ↓│↓ ↓│↓ ↓   β”‚               β”‚
β”‚  β”‚ K K K K K K  β”‚     K         β”‚  K β”‚ K β”‚ K    β”‚               β”‚
β”‚  β”‚ V V V V V V  β”‚     V         β”‚  V β”‚ V β”‚ V    β”‚               β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜               β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

6.2 Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA)

    Paper: https://arxiv.org/abs/2305.13245
    """
    def __init__(
        self,
        dim: int,
        n_heads: int = 32,
        n_kv_heads: int = 8,  # Number of KV heads (< n_heads)
        head_dim: int = None,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim or dim // n_heads

        # n_heads must be an integer multiple of n_kv_heads
        assert n_heads % n_kv_heads == 0
        self.n_rep = n_heads // n_kv_heads  # Number of Q heads per KV head

        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)

    def forward(self, x, freqs_cis=None, mask=None, kv_cache=None):
        batch, seq_len, _ = x.shape

        # Compute Q, K, V
        xq = self.wq(x).view(batch, seq_len, self.n_heads, self.head_dim)
        xk = self.wk(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
        xv = self.wv(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)

        # Apply RoPE (if available)
        if freqs_cis is not None:
            xq, xk = apply_rotary_emb(xq, xk, freqs_cis)

        # KV Cache handling (during inference): prepend cached keys/values
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            xk = torch.cat([cache_k, xk], dim=1)
            xv = torch.cat([cache_v, xv], dim=1)

        # Snapshot the cache BEFORE head expansion so it stays at n_kv_heads
        new_kv_cache = (xk, xv)

        # Expand KV heads: n_kv_heads → n_heads
        # (batch, seq, n_kv_heads, head_dim) → (batch, seq, n_heads, head_dim)
        xk = self._repeat_kv(xk)
        xv = self._repeat_kv(xv)

        # Compute Attention
        xq = xq.transpose(1, 2)  # (batch, n_heads, seq, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores + mask

        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, xv)

        # Combine heads and project back to model dimension
        output = output.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(output), new_kv_cache

    def _repeat_kv(self, x):
        """Repeat KV heads to match Q heads count"""
        if self.n_rep == 1:
            return x
        batch, seq_len, n_kv_heads, head_dim = x.shape
        x = x[:, :, :, None, :].expand(batch, seq_len, n_kv_heads, self.n_rep, head_dim)
        return x.reshape(batch, seq_len, n_kv_heads * self.n_rep, head_dim)

# Memory usage comparison
def compare_kv_cache_memory(seq_len, batch_size=1, dtype_bytes=2):
    """Compare KV cache memory (FP16 basis)"""
    configs = {
        "LLaMA-70B (MHA)": {"n_layers": 80, "n_kv_heads": 64, "head_dim": 128},
        "LLaMA-70B (GQA)": {"n_layers": 80, "n_kv_heads": 8, "head_dim": 128},
    }

    for name, cfg in configs.items():
        kv_mem = (2 *  # K and V
                  batch_size *
                  cfg["n_layers"] *
                  seq_len *
                  cfg["n_kv_heads"] *
                  cfg["head_dim"] *
                  dtype_bytes)
        print(f"{name}: {kv_mem / 1e9:.2f} GB for {seq_len} tokens")

compare_kv_cache_memory(4096)
# LLaMA-70B (MHA): 10.74 GB for 4096 tokens
# LLaMA-70B (GQA): 1.34 GB for 4096 tokens  ← 8x savings!
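Putting the pieces together, a minimal sketch of one LLaMA-style transformer block that composes the RMSNorm, SwiGLU, and GroupedQueryAttention classes defined above (pre-norm with residual connections, as in the Section 2.1 diagram; KV-cache plumbing omitted):

class LlamaBlock(nn.Module):
    """Pre-norm block: RMSNorm → GQA → residual → RMSNorm → SwiGLU → residual."""
    def __init__(self, dim=4096, n_heads=32, n_kv_heads=8):
        super().__init__()
        self.attn_norm = RMSNorm(dim)
        self.attn = GroupedQueryAttention(dim, n_heads=n_heads, n_kv_heads=n_kv_heads)
        self.ffn_norm = RMSNorm(dim)
        self.ffn = SwiGLU(dim)

    def forward(self, x, freqs_cis=None, mask=None):
        h, _ = self.attn(self.attn_norm(x), freqs_cis=freqs_cis, mask=mask)
        x = x + h                                 # residual around attention
        return x + self.ffn(self.ffn_norm(x))     # residual around FFN

block = LlamaBlock()
out = block(torch.randn(1, 16, 4096))
print(out.shape)  # torch.Size([1, 16, 4096])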

7. LLaMA Practice

7.1 Using with HuggingFace

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load LLaMA 2 7B
model_name = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Text generation
def generate_text(prompt, max_new_tokens=100, temperature=0.7):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Usage example
prompt = "Explain the concept of machine learning in simple terms:"
response = generate_text(prompt)
print(response)

7.2 Efficient Usage with Quantization

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# 4-bit quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",          # NormalFloat4
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,     # Double quantization
)

# Load quantized model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto"
)

# Check memory usage
print(f"Model memory: {model.get_memory_footprint() / 1e9:.2f} GB")
# Approximately 4GB (75% savings compared to FP16)

7.3 Using LLaMA 3

# LLaMA 3 8B Instruct (128K-token vocabulary)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # bfloat16 recommended for LLaMA 3
    device_map="auto"
)

# LLaMA 3 features:
# - 128K-token vocabulary (more efficient tokenization)
# - 8K base context (extended to 128K in LLaMA 3.1)
# - Improved reasoning capability

prompt = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
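Instead of writing the special tokens by hand as above, the chat template shipped with the Instruct checkpoint produces the same format:

messages = [{"role": "user", "content": "What is the capital of France?"}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,   # appends the assistant header for generation
    return_tensors="pt"
).to(model.device)

outputs = model.generate(inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))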

8. LLaMA 3.1/3.2 Details

8.1 LLaMA 3.1 (July 2024)

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    LLaMA 3.1 Key Features                        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚  1. 128K Native Context                                         β”‚
β”‚     β€’ Supports 128K tokens from training                        β”‚
β”‚     β€’ Handles long context without RoPE scaling                 β”‚
β”‚                                                                 β”‚
β”‚  2. 405B Flagship Model                                         β”‚
β”‚     β€’ GPT-4 level performance                                   β”‚
β”‚     β€’ 126 layers, 16K embedding dimension                       β”‚
β”‚                                                                 β”‚
β”‚  3. Tool Use Capability                                         β”‚
β”‚     β€’ Function Calling                                          β”‚
β”‚     β€’ Code Interpreter                                          β”‚
β”‚     β€’ Search tool integration                                   β”‚
β”‚                                                                 β”‚
β”‚  4. Enhanced Multilingual Support                               β”‚
β”‚     β€’ English, German, French, Italian                          β”‚
β”‚     β€’ Portuguese, Hindi, Spanish, Thai                          β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

# LLaMA 3.1 Tool Use Example
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "meta-llama/Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# Tool Use format (LLaMA 3.1 specialized)
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name"}
                },
                "required": ["location"]
            }
        }
    }
]

messages = [
    {"role": "system", "content": "You are a helpful assistant with access to tools."},
    {"role": "user", "content": "What's the weather in Seoul?"}
]

# Generate tool call
inputs = tokenizer.apply_chat_template(
    messages,
    tools=tools,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

outputs = model.generate(inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0]))

8.2 LLaMA 3.2 (September 2024)

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    LLaMA 3.2 Model Lineup                        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚  Lightweight Text Models (on-device optimized):                 β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                β”‚
β”‚  β”‚  LLaMA 3.2 1B: For mobile/edge devices      β”‚                β”‚
β”‚  β”‚  LLaMA 3.2 3B: For lightweight applications β”‚                β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                β”‚
β”‚                                                                 β”‚
β”‚  Vision Models (multimodal):                                    β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                β”‚
β”‚  β”‚  LLaMA 3.2 11B-Vision: Image understanding  β”‚                β”‚
β”‚  β”‚  LLaMA 3.2 90B-Vision: High-performance     β”‚                β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                β”‚
β”‚                                                                 β”‚
β”‚  Features:                                                       β”‚
β”‚  β€’ 1B/3B: 128K context, on-device inference capable             β”‚
β”‚  β€’ 11B/90B: Vision encoder integrated, image+text processing    β”‚
β”‚  β€’ Optimized for Qualcomm, MediaTek hardware                    β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

# LLaMA 3.2 Vision Example
from transformers import MllamaForConditionalGeneration, AutoProcessor
from PIL import Image
import requests
import torch

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)

# Load image
url = "https://example.com/image.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Vision conversation
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What do you see in this image?"}
        ]
    }
]

input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(image, input_text, return_tensors="pt").to(model.device)

output = model.generate(**inputs, max_new_tokens=256)
print(processor.decode(output[0]))

Summary

LLaMA Core Technologies

| Technology | Effect                                                |
|------------|-------------------------------------------------------|
| RoPE       | Relative position encoding; enables length extension  |
| RMSNorm    | Faster than LayerNorm with comparable quality         |
| SwiGLU     | Outperforms GELU FFN at a matched parameter count     |
| GQA        | ~8x KV cache memory savings (e.g., 64 → 8 KV heads)   |

Practical Recommendations

  1. 7B/8B: Single GPU (16GB+), for quick experiments
  2. 13B: 24GB GPU, balanced choice
  3. 70B: Multiple GPUs, when top performance needed
  4. Quantization: ~75% memory savings with 4-bit (see the sketch below)
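The memory numbers above follow directly from bytes per parameter. A minimal sketch of the arithmetic (weights only; activations, KV cache, and quantization scale overhead are ignored):

def weight_memory_gb(n_params_billion, bits):
    """Approximate weight memory in GB for a given parameter count and precision."""
    return n_params_billion * 1e9 * bits / 8 / 1e9

for bits in (16, 8, 4):
    print(f"{bits:>2}-bit: 7B ≈ {weight_memory_gb(7, bits):.1f} GB, "
          f"70B ≈ {weight_memory_gb(70, bits):.1f} GB")
# 16-bit: 7B ≈ 14.0 GB, 70B ≈ 140.0 GB
#  8-bit: 7B ≈  7.0 GB, 70B ≈  70.0 GB
#  4-bit: 7B ≈  3.5 GB, 70B ≈  35.0 GB  ← 75% less than FP16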

References

Core Papers

  • Touvron et al. (2023). "LLaMA: Open and Efficient Foundation Language Models"
  • Touvron et al. (2023). "LLaMA 2: Open Foundation and Fine-Tuned Chat Models"
  • Su et al. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding"
  • Ainslie et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"

Code & Resources

to navigate between lessons