10. Long Context Models

Overview

Self-attention in the standard Transformer has O(n²) complexity in the sequence length, which limits how long a sequence it can process. This lesson covers the main techniques for extending context length.


1. The Importance of Context Length

1.1 Why Do We Need Long Context?

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                   Long Context ์‚ฌ์šฉ ์‚ฌ๋ก€                         โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                                  โ”‚
โ”‚  ๐Ÿ“š ๋ฌธ์„œ ๋ถ„์„                                                    โ”‚
โ”‚  - ๋…ผ๋ฌธ ์ „์ฒด (10K-50K ํ† ํฐ)                                     โ”‚
โ”‚  - ๋ฒ•๋ฅ  ๋ฌธ์„œ (100K+ ํ† ํฐ)                                       โ”‚
โ”‚  - ์ฑ… ์ „์ฒด ์š”์•ฝ                                                  โ”‚
โ”‚                                                                  โ”‚
โ”‚  ๐Ÿ’ป ์ฝ”๋“œ ์ดํ•ด                                                    โ”‚
โ”‚  - ์ „์ฒด ์ฝ”๋“œ๋ฒ ์ด์Šค ๋ถ„์„                                          โ”‚
โ”‚  - ๊ธด ํ•จ์ˆ˜/ํด๋ž˜์Šค ๋ฆฌํŒฉํ† ๋ง                                       โ”‚
โ”‚  - ๋ฉ€ํ‹ฐํŒŒ์ผ ๋””๋ฒ„๊น…                                               โ”‚
โ”‚                                                                  โ”‚
โ”‚  ๐Ÿค– Agent ์‹œ์Šคํ…œ                                                 โ”‚
โ”‚  - ๊ธด ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ ์œ ์ง€                                         โ”‚
โ”‚  - ๋ณต์žกํ•œ ๋ฉ€ํ‹ฐ์Šคํ… ํƒœ์Šคํฌ                                        โ”‚
โ”‚  - Tool ์‚ฌ์šฉ ๊ธฐ๋ก ๋ˆ„์                                            โ”‚
โ”‚                                                                  โ”‚
โ”‚  ๐Ÿ” RAG ๊ฐœ์„                                                      โ”‚
โ”‚  - ๋” ๋งŽ์€ ๊ด€๋ จ ๋ฌธ์„œ ํฌํ•จ                                        โ”‚
โ”‚  - ๋ฌธ์„œ ์กฐ๊ฐ ๋Œ€์‹  ์ „์ฒด ๋ฌธ์„œ ์ œ๊ณต                                 โ”‚
โ”‚                                                                  โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

1.2 ๋ชจ๋ธ๋ณ„ ์ปจํ…์ŠคํŠธ ๊ธธ์ด ๋น„๊ต

๋ชจ๋ธ ์ปจํ…์ŠคํŠธ ๊ธธ์ด ์ถœ์‹œ ์‹œ๊ธฐ
GPT-3 2,048 2020
GPT-3.5 4,096 / 16,384 2022-2023
GPT-4 8,192 / 32,768 / 128K 2023-2024
Claude 2 100,000 2023
Claude 3 200,000 2024
Gemini 1.5 1,000,000 / 2,000,000 2024
LLaMA 2 4,096 2023
LLaMA 3 8,192 / 128K 2024
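
To see why the quadratic cost bites at these lengths, the back-of-the-envelope sketch below (illustrative numbers: fp16, 32 heads) estimates the memory needed just to materialize one layer's attention score matrix:

# Rough memory estimate for one layer's attention score matrix,
# assuming fp16 (2 bytes) and all heads materialized at once.
def attention_matrix_bytes(seq_len: int, num_heads: int = 32,
                           bytes_per_elem: int = 2) -> int:
    return num_heads * seq_len * seq_len * bytes_per_elem

for n in [2_048, 8_192, 32_768, 128_000, 1_000_000]:
    print(f"{n:>9} tokens -> {attention_matrix_bytes(n) / 2**30:>12,.1f} GiB per layer")
# Grows from ~0.3 GiB at 2K tokens to ~60,000 GiB at 1M tokens --
# which is why the sub-quadratic techniques below exist.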

2. ํšจ์œจ์ ์ธ Attention ๋ฉ”์ปค๋‹ˆ์ฆ˜

2.1 Sparse Attention

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                    Sparse Attention ํŒจํ„ด                    โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                             โ”‚
โ”‚  Local Attention        Global Attention                   โ”‚
โ”‚  โ–  โ–  โ–  โ–ก โ–ก โ–ก โ–ก         โ–  โ–ก โ–ก โ–ก โ–ก โ–ก โ–ก                      โ”‚
โ”‚  โ–  โ–  โ–  โ–  โ–ก โ–ก โ–ก         โ–  โ–  โ–ก โ–ก โ–ก โ–ก โ–ก                      โ”‚
โ”‚  โ–ก โ–  โ–  โ–  โ–  โ–ก โ–ก         โ–  โ–ก โ–  โ–ก โ–ก โ–ก โ–ก                      โ”‚
โ”‚  โ–ก โ–ก โ–  โ–  โ–  โ–  โ–ก         โ–  โ–ก โ–ก โ–  โ–ก โ–ก โ–ก                      โ”‚
โ”‚  โ–ก โ–ก โ–ก โ–  โ–  โ–  โ–          โ–  โ–ก โ–ก โ–ก โ–  โ–ก โ–ก                      โ”‚
โ”‚  โ–ก โ–ก โ–ก โ–ก โ–  โ–  โ–          โ–  โ–ก โ–ก โ–ก โ–ก โ–  โ–ก                      โ”‚
โ”‚  โ–ก โ–ก โ–ก โ–ก โ–ก โ–  โ–          โ–  โ–ก โ–ก โ–ก โ–ก โ–ก โ–                       โ”‚
โ”‚                                                             โ”‚
โ”‚  Longformer: Local + Global ํ† ํฐ ์กฐํ•ฉ                       โ”‚
โ”‚  BigBird: Local + Global + Random                          โ”‚
โ”‚                                                             โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

2.2 Implementing Longformer

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LongformerAttention(nn.Module):
    """
    Longformer: Sliding Window + Global Attention

    ๋ณต์žก๋„: O(n ร— w) where w = window size
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        window_size: int = 256,
        global_tokens: int = 2  # [CLS], [SEP]
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.window_size = window_size
        self.global_tokens = global_tokens

        # Q, K, V projections
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

        # Separate projections for global attention
        self.global_query = nn.Linear(hidden_size, hidden_size)
        self.global_key = nn.Linear(hidden_size, hidden_size)
        self.global_value = nn.Linear(hidden_size, hidden_size)

        self.output = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq_len, hidden_size)
            attention_mask: (batch, seq_len)
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Q, K, V ๊ณ„์‚ฐ
        Q = self.query(hidden_states)
        K = self.key(hidden_states)
        V = self.value(hidden_states)

        # Reshape: (batch, seq_len, num_heads, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # 1. Sliding Window Attention (local)
        local_output = self._sliding_window_attention(Q, K, V)

        # 2. Global attention (the first global_tokens positions)
        global_output = self._global_attention(hidden_states)

        # ๊ฒฐํ•ฉ (global ํ† ํฐ ์œ„์น˜๋Š” global ๊ฒฐ๊ณผ ์‚ฌ์šฉ)
        output = local_output.clone()
        output[:, :self.global_tokens] = global_output[:, :self.global_tokens]

        # Output projection
        output = output.view(batch_size, seq_len, self.hidden_size)
        output = self.output(output)

        return output

    def _sliding_window_attention(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor
    ) -> torch.Tensor:
        """
        Sliding Window Attention

        ๊ฐ ํ† ํฐ์€ window_size ๋ฒ”์œ„ ๋‚ด์˜ ํ† ํฐ๋งŒ ์ฐธ์กฐ
        """
        batch_size, seq_len, num_heads, head_dim = Q.shape
        w = self.window_size // 2

        # ํŒจ๋”ฉ ์ถ”๊ฐ€
        Q_padded = F.pad(Q, (0, 0, 0, 0, w, w), value=0)
        K_padded = F.pad(K, (0, 0, 0, 0, w, w), value=0)
        V_padded = F.pad(V, (0, 0, 0, 0, w, w), value=0)

        # ์œˆ๋„์šฐ ์ถ”์ถœ (unfold)
        # ์‹ค์ œ ๊ตฌํ˜„์€ ๋” ๋ณต์žกํ•˜์ง€๋งŒ ๊ฐœ๋… ์ดํ•ด์šฉ ๊ฐ„์†Œํ™” ๋ฒ„์ „
        output = torch.zeros_like(Q)

        for i in range(seq_len):
            # i๋ฒˆ์งธ ํ† ํฐ์˜ ์œˆ๋„์šฐ: [i, i + window_size]
            start = i
            end = i + self.window_size

            q_i = Q[:, i:i+1]  # (batch, 1, heads, dim)
            k_window = K_padded[:, start:end]  # (batch, window, heads, dim)
            v_window = V_padded[:, start:end]

            # Attention
            scores = torch.einsum('bihd,bjhd->bijh', q_i, k_window)
            scores = scores / math.sqrt(head_dim)
            weights = F.softmax(scores, dim=2)
            out_i = torch.einsum('bijh,bjhd->bihd', weights, v_window)

            output[:, i] = out_i[:, 0]

        return output

    def _global_attention(
        self,
        hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """Global attention: global tokens attend over the full sequence."""
        batch_size, seq_len, _ = hidden_states.shape

        # Extract only the global tokens
        global_hidden = hidden_states[:, :self.global_tokens]

        # Global Q, K, V
        global_Q = self.global_query(global_hidden)
        global_K = self.global_key(hidden_states)
        global_V = self.global_value(hidden_states)

        # ์ „์ฒด ์‹œํ€€์Šค์— ๋Œ€ํ•ด attention
        global_Q = global_Q.view(batch_size, self.global_tokens,
                                  self.num_heads, self.head_dim)
        global_K = global_K.view(batch_size, seq_len,
                                  self.num_heads, self.head_dim)
        global_V = global_V.view(batch_size, seq_len,
                                  self.num_heads, self.head_dim)

        # (batch, global, heads, seq) attention
        scores = torch.einsum('bghd,bshd->bghs', global_Q, global_K)
        scores = scores / math.sqrt(self.head_dim)
        weights = F.softmax(scores, dim=-1)

        # Output: (batch, global, heads, dim)
        output = torch.einsum('bghs,bshd->bghd', weights, global_V)

        return output
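
A quick smoke test of the module above (shapes are illustrative):

attn = LongformerAttention(hidden_size=256, num_heads=4, window_size=16)
x = torch.randn(2, 128, 256)   # (batch, seq_len, hidden_size)
out = attn(x)
print(out.shape)               # torch.Size([2, 128, 256])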

2.3 Flash Attention

# Flash Attention์€ CUDA ์ปค๋„๋กœ ๊ตฌํ˜„๋˜์–ด ์žˆ์Œ
# ์—ฌ๊ธฐ์„œ๋Š” ๊ฐœ๋…๋งŒ ์„ค๋ช…

"""
Flash Attention ํ•ต์‹ฌ ์•„์ด๋””์–ด:

1. ํƒ€์ผ๋ง (Tiling):
   - Q, K, V๋ฅผ SRAM์— ๋งž๋Š” ๋ธ”๋ก์œผ๋กœ ๋ถ„ํ• 
   - HBM โ†” SRAM ๋ฐ์ดํ„ฐ ์ „์†ก ์ตœ์†Œํ™”

2. ์žฌ๊ณ„์‚ฐ (Recomputation):
   - Forward์—์„œ attention weights ์ €์žฅ ์•ˆ ํ•จ
   - Backward์—์„œ ํ•„์š”ํ•  ๋•Œ ์žฌ๊ณ„์‚ฐ
   - ๋ฉ”๋ชจ๋ฆฌ ์ ˆ์•ฝ (O(n) vs O(nยฒ))

3. ๊ฒฐ๊ณผ:
   - ๋ฉ”๋ชจ๋ฆฌ: O(n) vs O(nยฒ)
   - ์†๋„: 2-4x ๋น ๋ฆ„
   - ์ •ํ™•๋„: ์ˆ˜์น˜์ ์œผ๋กœ ๋™์ผ
"""

# PyTorch 2.0+์—์„œ ์‚ฌ์šฉ
def use_flash_attention():
    import torch.nn.functional as F

    # Scaled dot-product attention (Flash Attention is used automatically)
    Q = torch.randn(2, 8, 1024, 64, device='cuda')
    K = torch.randn(2, 8, 1024, 64, device='cuda')
    V = torch.randn(2, 8, 1024, 64, device='cuda')

    # PyTorch 2.0+ SDPA context manager; newer releases prefer
    # torch.nn.attention.sdpa_kernel (this one is deprecated there)
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,
        enable_mem_efficient=False
    ):
        output = F.scaled_dot_product_attention(Q, K, V)

    return output


# Usage with xFormers
def use_xformers():
    from xformers.ops import memory_efficient_attention

    # Note: xFormers expects (batch, seq_len, num_heads, head_dim) layout
    Q = torch.randn(2, 1024, 8, 64, device='cuda')
    K = torch.randn(2, 1024, 8, 64, device='cuda')
    V = torch.randn(2, 1024, 8, 64, device='cuda')

    output = memory_efficient_attention(Q, K, V)
    return output
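
Since Flash Attention computes exact attention, its output should match a naive reference implementation up to floating-point error. A small sanity check along those lines (assuming a CUDA device is available):

def check_sdpa_matches_naive():
    Q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
    K = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
    V = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)

    fast = F.scaled_dot_product_attention(Q, K, V)

    # Naive reference: materializes the full (seq, seq) score matrix
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])
    naive = torch.softmax(scores, dim=-1) @ V

    print(torch.allclose(fast, naive, atol=1e-2))  # True within fp16 tolerance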

3. ์œ„์น˜ ์ธ์ฝ”๋”ฉ ํ™•์žฅ

3.1 ๋ฌธ์ œ: ํ•™์Šต ๊ธธ์ด๋ฅผ ๋„˜์–ด์„œ ์™ธ์‚ฝ

ํ•™์Šต: 4096 ํ† ํฐ
์ถ”๋ก : 8192+ ํ† ํฐ

๋ฌธ์ œ:
- ์ ˆ๋Œ€ ์œ„์น˜ ์ธ์ฝ”๋”ฉ: 4096 ์ดํ›„ ์œ„์น˜ ํ•™์Šต ์•ˆ ๋จ
- RoPE: ๋ณด๊ฐ„/์™ธ์‚ฝ ํ•„์š”

3.2 Position Interpolation (PI)

def linear_position_interpolation(
    position_ids: torch.Tensor,
    original_max_length: int,
    extended_max_length: int
) -> torch.Tensor:
    """
    Linear Position Interpolation

    ์•„์ด๋””์–ด: ์ƒˆ ์œ„์น˜๋ฅผ ์›๋ณธ ๋ฒ”์œ„๋กœ ์Šค์ผ€์ผ๋ง

    position_ids๋ฅผ [0, original_max_length)๋กœ ์••์ถ•
    """
    scale = original_max_length / extended_max_length
    return position_ids.float() * scale
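
For example, extending a 4K model to 8K maps position 8191 to 4095.5, a fractional position that RoPE's continuous rotations can represent:

pos = torch.arange(8192)
scaled = linear_position_interpolation(pos, original_max_length=4096,
                                       extended_max_length=8192)
print(scaled[:4])   # tensor([0.0000, 0.5000, 1.0000, 1.5000])
print(scaled[-1])   # tensor(4095.5000)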


class RoPEWithInterpolation(nn.Module):
    """Position Interpolation์ด ์ ์šฉ๋œ RoPE"""

    def __init__(
        self,
        dim: int,
        original_max_length: int = 4096,
        extended_max_length: int = 16384,
        base: float = 10000.0
    ):
        super().__init__()
        self.dim = dim
        self.original_max_length = original_max_length
        self.extended_max_length = extended_max_length
        self.base = base

        # ์ฃผํŒŒ์ˆ˜ ๊ณ„์‚ฐ
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # ์Šค์ผ€์ผ ํŒฉํ„ฐ
        self.scale = original_max_length / extended_max_length

    def forward(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, heads, dim)
            position_ids: (batch, seq_len)
        """
        # ์œ„์น˜ ๋ณด๊ฐ„
        scaled_positions = position_ids.float() * self.scale

        # ์ฃผํŒŒ์ˆ˜ ๊ณ„์‚ฐ
        freqs = torch.einsum('bi,d->bid', scaled_positions, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)

        cos = emb.cos().unsqueeze(2)  # (batch, seq, 1, dim)
        sin = emb.sin().unsqueeze(2)

        # RoPE ์ ์šฉ
        x_rope = self._apply_rope(x, cos, sin)

        return x_rope

    def _apply_rope(self, x, cos, sin):
        """RoPE ์ ์šฉ"""
        x1 = x[..., : self.dim // 2]
        x2 = x[..., self.dim // 2 :]

        rotated = torch.cat([-x2, x1], dim=-1)
        return x * cos + rotated * sin
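
A minimal usage sketch (shapes are illustrative):

rope = RoPEWithInterpolation(dim=64, original_max_length=4096,
                             extended_max_length=16384)
x = torch.randn(1, 8192, 8, 64)            # (batch, seq, heads, dim)
pos = torch.arange(8192).unsqueeze(0)      # (batch, seq)
print(rope(x, pos).shape)                  # torch.Size([1, 8192, 8, 64])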

3.3 YaRN (Yet another RoPE extension method)

class YaRNRoPE(nn.Module):
    """
    YaRN: NTK-aware interpolation.

    Problem with plain position interpolation:
    - all frequencies are compressed equally, losing the
      high-frequency (short-range) information in the upper dims

    YaRN's remedy:
    - low-frequency dims: interpolate
    - high-frequency dims: extrapolate (leave unchanged)
    - blend the two per dimension, plus a small attention
      temperature rescaling (mscale)
    """

    def __init__(
        self,
        dim: int,
        original_max_length: int = 4096,
        extended_max_length: int = 32768,
        base: float = 10000.0,
        beta_fast: float = 32,
        beta_slow: float = 1,
    ):
        super().__init__()
        self.dim = dim
        self.original_max_length = original_max_length
        self.extended_max_length = extended_max_length

        scale = extended_max_length / original_max_length

        # Per-dimension interpolation ratio:
        # low dims (high frequency): mostly extrapolation
        # high dims (low frequency): mostly interpolation
        # The boundary dims come from how many full rotations a dimension
        # completes within the original context (beta_fast / beta_slow).
        def find_correction_dim(num_rotations: float) -> float:
            return (dim * math.log(original_max_length / (num_rotations * 2 * math.pi))) \
                   / (2 * math.log(base))

        low = max(math.floor(find_correction_dim(beta_fast)), 0)
        high = min(math.ceil(find_correction_dim(beta_slow)), dim // 2 - 1)

        # Linear ramp deciding the mix per dim: 0 = extrapolate, 1 = interpolate
        ramp = torch.clamp(
            (torch.arange(dim // 2, dtype=torch.float32) - low) / max(high - low, 1),
            0.0, 1.0
        )

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

        # Blend extrapolated (original) and interpolated (scaled) frequencies
        inv_freq_inter = inv_freq / scale
        self.register_buffer(
            'inv_freq',
            (1 - ramp) * inv_freq + ramp * inv_freq_inter
        )

        # Attention temperature scaling from the YaRN paper
        self.mscale = 0.1 * math.log(scale) + 1.0

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        # Rotation angles (using the already-adjusted inv_freq)
        freqs = torch.einsum('bi,d->bid', position_ids.float(), self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)

        cos = emb.cos().unsqueeze(2) * self.mscale
        sin = emb.sin().unsqueeze(2) * self.mscale

        return self._apply_rope(x, cos, sin)

    def _apply_rope(self, x, cos, sin):
        """Rotate-half formulation of RoPE (same as above)."""
        x1 = x[..., : self.dim // 2]
        x2 = x[..., self.dim // 2 :]
        rotated = torch.cat([-x2, x1], dim=-1)
        return x * cos + rotated * sin
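
To see the per-dimension blend in action, compare the adjusted frequencies against unscaled RoPE frequencies (values assume the sketch above with an 8x extension):

base_inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))
yarn = YaRNRoPE(dim=64, original_max_length=4096, extended_max_length=32768)

# Highest-frequency dims are (almost) untouched -> extrapolation;
# lowest-frequency dims are divided by the full 8x scale -> interpolation.
print(yarn.inv_freq[0] / base_inv_freq[0])    # ~1.0
print(yarn.inv_freq[-1] / base_inv_freq[-1])  # ~0.125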

4. ALiBi (Attention with Linear Biases)

4.1 Concept

ALiBi: positional information without learned position embeddings.

Idea:
- Use no positional encodings at all
- Instead, add a distance-based penalty to the attention scores
- The farther apart two tokens are, the lower their attention score

Attention score modification:
score(q_i, k_j) = q_i · k_j - m × |i - j|

m: a fixed, non-learned slope per head. For H heads the slopes form the
geometric sequence m_h = 2^(-8h/H), h = 1..H; e.g. for H = 8 the slopes
are 1/2, 1/4, ..., 1/256.

4.2 ๊ตฌํ˜„

class ALiBiAttention(nn.Module):
    """ALiBi: Attention with Linear Biases"""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_seq_len: int = 8192
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, hidden_size)

        # ALiBi slopes: geometric sequence
        # 2^(-8/H), 2^(-8·2/H), ..., 2^(-8)
        slopes = self._get_alibi_slopes(num_heads)
        self.register_buffer('slopes', slopes)

        # ๊ฑฐ๋ฆฌ ํ–‰๋ ฌ ์‚ฌ์ „ ๊ณ„์‚ฐ
        positions = torch.arange(max_seq_len)
        distance_matrix = positions.unsqueeze(0) - positions.unsqueeze(1)
        distance_matrix = distance_matrix.abs()
        self.register_buffer('distance_matrix', distance_matrix)

    def _get_alibi_slopes(self, num_heads: int) -> torch.Tensor:
        """Head๋ณ„ ALiBi slope ๊ณ„์‚ฐ"""

        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(num_heads).is_integer():
            slopes = get_slopes_power_of_2(num_heads)
        else:
            # ๊ฐ€์žฅ ๊ฐ€๊นŒ์šด 2์˜ ๊ฑฐ๋“ญ์ œ๊ณฑ์œผ๋กœ ๋ณด๊ฐ„
            closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
            slopes = get_slopes_power_of_2(closest_power_of_2)

            extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)
            extra_slopes = extra_slopes[0::2][:num_heads - closest_power_of_2]
            slopes = slopes + extra_slopes

        return torch.tensor(slopes).view(1, num_heads, 1, 1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape

        # Q, K, V ๊ณ„์‚ฐ
        Q = self.query(hidden_states)
        K = self.key(hidden_states)
        V = self.value(hidden_states)

        # Reshape
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Transpose: (batch, heads, seq, dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # ALiBi bias: -m ร— |i - j|
        alibi_bias = -self.slopes * self.distance_matrix[:seq_len, :seq_len]
        scores = scores + alibi_bias

        # Causal mask
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=scores.device) * float('-inf'),
            diagonal=1
        )
        scores = scores + causal_mask

        # Attention weights
        weights = F.softmax(scores, dim=-1)

        # Output
        output = torch.matmul(weights, V)
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, seq_len, self.hidden_size)
        output = self.output(output)

        return output
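
A smoke test; because the distance matrix is precomputed, any sequence up to max_seq_len works without position embeddings:

attn = ALiBiAttention(hidden_size=256, num_heads=8, max_seq_len=4096)
x = torch.randn(2, 100, 256)
print(attn(x).shape)  # torch.Size([2, 100, 256])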

5. Ring Attention

5.1 Concept

Ring Attention: distributed long context.

Idea:
- Shard the sequence across multiple GPUs
- Each GPU processes its own Q chunk against KV blocks circulating around the ring
- Communication is overlapped with computation

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                Ring Attention                   โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                โ”‚
โ”‚  GPU 0: Q[0:n/4]     GPU 1: Q[n/4:n/2]        โ”‚
โ”‚          โ†“ KV ์ˆœํ™˜        โ†“ KV ์ˆœํ™˜            โ”‚
โ”‚  Step 1: K[0:n/4]    Step 1: K[n/4:n/2]       โ”‚
โ”‚  Step 2: K[n/4:n/2]  Step 2: K[n/2:3n/4]      โ”‚
โ”‚  Step 3: K[n/2:3n/4] Step 3: K[3n/4:n]        โ”‚
โ”‚  Step 4: K[3n/4:n]   Step 4: K[0:n/4]         โ”‚
โ”‚                                                โ”‚
โ”‚  KV๊ฐ€ ๋ง์ฒ˜๋Ÿผ ์ˆœํ™˜ํ•˜๋ฉฐ ๊ฐ GPU์˜ Q์™€ ๊ฒฐํ•ฉ         โ”‚
โ”‚                                                โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

5.2 ๊ตฌํ˜„ ๊ฐœ์š”

import torch.distributed as dist

def ring_attention_forward(
    Q: torch.Tensor,  # local Q chunk: (batch, heads, local_seq, dim)
    K: torch.Tensor,  # local K chunk: (batch, heads, local_seq, dim)
    V: torch.Tensor,  # local V chunk: (batch, heads, local_seq, dim)
    world_size: int,
    rank: int
):
    """
    Ring Attention forward pass (conceptual implementation).

    A production implementation needs fused kernels and careful
    communication/computation overlap.
    """
    local_seq_len = Q.shape[2]

    # Accumulated (unnormalized) attention output and online-softmax stats
    output = torch.zeros_like(Q)
    max_scores = torch.full(
        (Q.shape[0], Q.shape[1], local_seq_len), float('-inf'), device=Q.device
    )
    sum_exp = torch.zeros_like(max_scores)

    # The KV block currently held by this rank
    current_K = K.clone()
    current_V = V.clone()

    for step in range(world_size):
        # Attention of the local Q against the current KV block
        scores = torch.matmul(Q, current_K.transpose(-2, -1))
        scores = scores / math.sqrt(Q.shape[-1])

        # Online softmax (numerically stable streaming update)
        new_max = torch.maximum(scores.max(dim=-1).values, max_scores)
        exp_scores = torch.exp(scores - new_max.unsqueeze(-1))

        # Rescale previously accumulated results to the new max
        scale = torch.exp(max_scores - new_max)
        output = output * scale.unsqueeze(-1) + torch.matmul(exp_scores, current_V)

        sum_exp = sum_exp * scale + exp_scores.sum(dim=-1)
        max_scores = new_max

        # Pass KV to the next GPU in the ring
        if step < world_size - 1:
            send_rank = (rank + 1) % world_size
            recv_rank = (rank - 1) % world_size

            # Send our block forward, receive the previous rank's block
            current_K = ring_pass(current_K, send_rank, recv_rank)
            current_V = ring_pass(current_V, send_rank, recv_rank)

    # Final normalization
    output = output / sum_exp.unsqueeze(-1)

    return output


def ring_pass(tensor, send_rank, recv_rank):
    """Pass a tensor one step around the ring topology."""
    recv_tensor = torch.empty_like(tensor)

    send_op = dist.isend(tensor, send_rank)
    recv_op = dist.irecv(recv_tensor, recv_rank)

    send_op.wait()
    recv_op.wait()

    return recv_tensor
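
The sketch below shows one way to exercise ring_attention_forward end to end. It runs on CPU with the gloo backend under torchrun (e.g. `torchrun --nproc_per_node=4 ring_demo.py`); all names and shapes are illustrative, and every rank builds the same full tensors from a shared seed purely so the result can be checked against ordinary full attention.

def ring_demo():
    dist.init_process_group(backend='gloo')  # gloo runs on CPU; use nccl on GPUs
    rank, world_size = dist.get_rank(), dist.get_world_size()

    batch, heads, total_seq, head_dim = 2, 4, 64, 16
    chunk = total_seq // world_size

    torch.manual_seed(0)  # identical tensors on every rank (for checking only)
    Q = torch.randn(batch, heads, total_seq, head_dim)
    K = torch.randn(batch, heads, total_seq, head_dim)
    V = torch.randn(batch, heads, total_seq, head_dim)

    # Each rank keeps only its slice of the sequence dimension
    sl = slice(rank * chunk, (rank + 1) * chunk)
    out = ring_attention_forward(
        Q[:, :, sl].contiguous(), K[:, :, sl].contiguous(),
        V[:, :, sl].contiguous(), world_size, rank
    )

    # Reference: plain (non-causal) full attention computed locally
    ref = torch.softmax(Q @ K.transpose(-2, -1) / head_dim ** 0.5, dim=-1) @ V
    print(f"rank {rank}: match = {torch.allclose(out, ref[:, :, sl], atol=1e-5)}")
    dist.destroy_process_group()

if __name__ == '__main__':
    ring_demo()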

6. ์‹ค์šฉ์  ๊ฐ€์ด๋“œ

6.1 ์ปจํ…์ŠคํŠธ ํ™•์žฅ ๋ฐฉ๋ฒ• ์„ ํƒ

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚              ์–ธ์ œ ์–ด๋–ค ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ• ๊นŒ?                       โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                              โ”‚
โ”‚  4K โ†’ 8K:                                                    โ”‚
โ”‚  - Position Interpolation (๊ฐ„๋‹จ, ์„ฑ๋Šฅ ์ข‹์Œ)                  โ”‚
โ”‚  - ์•ฝ๊ฐ„์˜ fine-tuning ๊ถŒ์žฅ                                   โ”‚
โ”‚                                                              โ”‚
โ”‚  4K โ†’ 32K:                                                   โ”‚
โ”‚  - YaRN (PI๋ณด๋‹ค ์„ฑ๋Šฅ ์ข‹์Œ)                                   โ”‚
โ”‚  - ๋˜๋Š” ALiBi (์ฒ˜์Œ๋ถ€ํ„ฐ ํ•™์Šต ์‹œ)                             โ”‚
โ”‚                                                              โ”‚
โ”‚  32K โ†’ 100K+:                                                โ”‚
โ”‚  - Flash Attention ํ•„์ˆ˜                                      โ”‚
โ”‚  - Ring Attention (๋‹ค์ค‘ GPU)                                 โ”‚
โ”‚  - Sparse Attention ๊ณ ๋ ค                                     โ”‚
โ”‚                                                              โ”‚
โ”‚  1M+:                                                        โ”‚
โ”‚  - ํŠน์ˆ˜ ์•„ํ‚คํ…์ฒ˜ ํ•„์š”                                        โ”‚
โ”‚  - Mamba/State Space Models                                  โ”‚
โ”‚  - ๋˜๋Š” ๊ทน๋„๋กœ ํฌ์†Œํ•œ attention                              โ”‚
โ”‚                                                              โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

6.2 ์‹ค์ „ ํŒ

# 1. Gradient checkpointing is essential for long sequences
model.gradient_checkpointing_enable()

# 2. Use mixed precision
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    outputs = model(**inputs)

# 3. Optimize the KV cache at inference time
# - sliding-window cache (a minimal sketch follows these tips)
# - Paged Attention (vLLM)

# 4. Process long documents chunk by chunk
def process_long_document(model, document, chunk_size=4096, overlap=512):
    """Split a long document into overlapping chunks and process each.

    Assumes a `tokenizer` and a `merge_results` helper defined elsewhere.
    """
    tokens = tokenizer.encode(document)
    results = []

    # Step by chunk_size - overlap so consecutive chunks share context
    for i in range(0, len(tokens), chunk_size - overlap):
        chunk = tokens[i:i + chunk_size]
        output = model.generate(chunk)
        results.append(output)

    return merge_results(results)
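
The sliding-window cache referenced in tip 3 above, as a minimal hypothetical sketch (illustrative, not vLLM's API): it keeps only the most recent `window` key/value entries, so decoding memory stays constant regardless of output length, at the cost of forgetting tokens that fall out of the window.

class SlidingWindowKVCache:
    """Minimal sliding-window KV cache sketch."""

    def __init__(self, window: int = 4096):
        self.window = window
        self.k = None  # (batch, heads, cached_len, head_dim)
        self.v = None

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor):
        # Append the new step's K/V, then drop everything older than `window`
        self.k = k_new if self.k is None else torch.cat([self.k, k_new], dim=2)
        self.v = v_new if self.v is None else torch.cat([self.v, v_new], dim=2)
        self.k = self.k[:, :, -self.window:]
        self.v = self.v[:, :, -self.window:]
        return self.k, self.v


# Usage during autoregressive decoding: one (batch, heads, 1, dim) step at a time
cache = SlidingWindowKVCache(window=8)
for _ in range(20):
    k_t = torch.randn(1, 4, 1, 64)
    v_t = torch.randn(1, 4, 1, 64)
    K_ctx, V_ctx = cache.update(k_t, v_t)
print(K_ctx.shape)  # torch.Size([1, 4, 8, 64]) -- capped at the window size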

์ฐธ๊ณ  ์ž๋ฃŒ

๋…ผ๋ฌธ

  • Beltagy et al. (2020). "Longformer: The Long-Document Transformer"
  • Dao et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness"
  • Press et al. (2021). "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
  • Peng et al. (2023). "YaRN: Efficient Context Window Extension of Large Language Models"

๊ด€๋ จ ๋ ˆ์Šจ

to navigate between lessons