# 10. Long Context Models

## Overview

Standard Transformer self-attention has O(n²) time and memory complexity, which limits its ability to process long sequences. This lesson covers techniques for extending the context length.
## 1. Why Context Length Matters

### 1.1 Why Do We Need Long Context?
Common long-context use cases:

- **Document analysis**: full papers (10K-50K tokens), legal documents (100K+ tokens), summarizing entire books
- **Code understanding**: whole-codebase analysis, refactoring long functions and classes, multi-file debugging
- **Agent systems**: maintaining long conversation history, complex multi-step tasks, accumulating tool-use logs
- **RAG improvement**: including more relevant documents, passing whole documents instead of chunks
### 1.2 Context Length by Model

| Model | Context length (tokens) | Released |
|---|---|---|
| GPT-3 | 2,048 | 2020 |
| GPT-3.5 | 4,096 / 16,384 | 2022-2023 |
| GPT-4 | 8,192 / 32,768 / 128K | 2023-2024 |
| Claude 2 | 100,000 | 2023 |
| Claude 3 | 200,000 | 2024 |
| Gemini 1.5 | 1,000,000 / 2,000,000 | 2024 |
| LLaMA 2 | 4,096 | 2023 |
| LLaMA 3 | 8,192 / 128K | 2024 |
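As the overview notes, full attention cost grows quadratically with sequence length. A quick back-of-envelope sketch (fp16 scores, an illustrative 32-head model; neither value is tied to a specific model in the table):

```python
# Back-of-envelope memory for the full attention-score matrix (fp16), showing
# why O(n^2) becomes prohibitive at long contexts. num_heads=32 is an
# illustrative value, not taken from any model above.
def attention_matrix_bytes(seq_len: int, num_heads: int = 32, bytes_per_el: int = 2) -> int:
    # One (seq_len x seq_len) score matrix per head, per layer
    return seq_len * seq_len * num_heads * bytes_per_el

short = attention_matrix_bytes(2_048)    # GPT-3-era context
long = attention_matrix_bytes(131_072)   # 128K context: 64x longer, 4096x the memory
```

Doubling the context quadruples the score-matrix memory, which is why the efficient-attention techniques below exist.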
## 2. Efficient Attention Mechanisms

### 2.1 Sparse Attention
Sparse attention computes only a structured subset of the full n×n score matrix:

- **Local (sliding-window) attention**: each token attends only to tokens within a fixed window of its own position, giving a banded attention matrix.
- **Global attention**: a few designated tokens (e.g. `[CLS]`, `[SEP]`) attend to, and are attended to by, every position.

Combinations used in practice:

- **Longformer**: local window + global tokens
- **BigBird**: local + global + random connections
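The combined pattern can be sketched as a boolean attention mask. The helper below is illustrative (the name `longformer_mask` and the toy values for `window` and `n_global` are ours, not Longformer defaults):

```python
import torch

# Sketch of a Longformer-style sparse mask: sliding window + a few global tokens.
def longformer_mask(seq_len: int, window: int, n_global: int) -> torch.Tensor:
    i = torch.arange(seq_len)
    # Local band: positions within window // 2 of each other
    mask = (i.unsqueeze(0) - i.unsqueeze(1)).abs() <= window // 2
    # Global tokens attend everywhere and are attended to by everyone
    mask[:n_global, :] = True
    mask[:, :n_global] = True
    return mask

mask = longformer_mask(seq_len=16, window=4, n_global=2)
```

Only the `True` entries are computed, so the cost is O(n·w) plus a thin global stripe rather than O(n²).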
### 2.2 Longformer Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class LongformerAttention(nn.Module):
"""
Longformer: Sliding Window + Global Attention
    Complexity: O(n × w), where w = window size
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
window_size: int = 256,
global_tokens: int = 2 # [CLS], [SEP]
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.window_size = window_size
self.global_tokens = global_tokens
# Q, K, V projections
self.query = nn.Linear(hidden_size, hidden_size)
self.key = nn.Linear(hidden_size, hidden_size)
self.value = nn.Linear(hidden_size, hidden_size)
        # separate projections for global attention
self.global_query = nn.Linear(hidden_size, hidden_size)
self.global_key = nn.Linear(hidden_size, hidden_size)
self.global_value = nn.Linear(hidden_size, hidden_size)
self.output = nn.Linear(hidden_size, hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor = None
) -> torch.Tensor:
"""
Args:
hidden_states: (batch, seq_len, hidden_size)
attention_mask: (batch, seq_len)
"""
batch_size, seq_len, _ = hidden_states.shape
        # compute Q, K, V
Q = self.query(hidden_states)
K = self.key(hidden_states)
V = self.value(hidden_states)
# Reshape: (batch, seq_len, num_heads, head_dim)
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)
# 1. Sliding Window Attention (local)
local_output = self._sliding_window_attention(Q, K, V)
        # 2. Global attention (the first `global_tokens` positions)
global_output = self._global_attention(
hidden_states, Q, K, V
)
        # combine: positions of global tokens take the global-attention result
output = local_output.clone()
output[:, :self.global_tokens] = global_output[:, :self.global_tokens]
# Output projection
output = output.view(batch_size, seq_len, self.hidden_size)
output = self.output(output)
return output
def _sliding_window_attention(
self,
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor
) -> torch.Tensor:
"""
        Sliding-window attention: each token attends only to tokens
        within `window_size` of its own position.
"""
batch_size, seq_len, num_heads, head_dim = Q.shape
w = self.window_size // 2
        # pad K and V so every position has a full window on both sides
        K_padded = F.pad(K, (0, 0, 0, 0, w, w), value=0)
        V_padded = F.pad(V, (0, 0, 0, 0, w, w), value=0)
        # Simplified per-position loop for clarity; real implementations
        # extract windows with unfold/as_strided instead of a Python loop
output = torch.zeros_like(Q)
for i in range(seq_len):
            # window for token i covers [i, i + window_size) in the padded K/V
start = i
end = i + self.window_size
q_i = Q[:, i:i+1] # (batch, 1, heads, dim)
k_window = K_padded[:, start:end] # (batch, window, heads, dim)
v_window = V_padded[:, start:end]
# Attention
scores = torch.einsum('bihd,bjhd->bijh', q_i, k_window)
scores = scores / math.sqrt(head_dim)
weights = F.softmax(scores, dim=2)
out_i = torch.einsum('bijh,bjhd->bihd', weights, v_window)
output[:, i] = out_i[:, 0]
return output
def _global_attention(
self,
hidden_states: torch.Tensor,
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor
) -> torch.Tensor:
"""Global Attention: global ํ ํฐ์ ์ ์ฒด ์ํ์ค ์ฐธ์กฐ"""
batch_size, seq_len, _ = hidden_states.shape
        # extract the global tokens
global_hidden = hidden_states[:, :self.global_tokens]
# Global Q, K, V
global_Q = self.global_query(global_hidden)
global_K = self.global_key(hidden_states)
global_V = self.global_value(hidden_states)
        # attend over the entire sequence
global_Q = global_Q.view(batch_size, self.global_tokens,
self.num_heads, self.head_dim)
global_K = global_K.view(batch_size, seq_len,
self.num_heads, self.head_dim)
global_V = global_V.view(batch_size, seq_len,
self.num_heads, self.head_dim)
# (batch, global, heads, seq) attention
scores = torch.einsum('bghd,bshd->bghs', global_Q, global_K)
scores = scores / math.sqrt(self.head_dim)
weights = F.softmax(scores, dim=-1)
# Output: (batch, global, heads, dim)
output = torch.einsum('bghs,bshd->bghd', weights, global_V)
return output
### 2.3 Flash Attention
# Flash Attention is implemented as fused CUDA kernels;
# this section only sketches the idea
"""
Flash Attention ํต์ฌ ์์ด๋์ด:
1. ํ์ผ๋ง (Tiling):
- Q, K, V๋ฅผ SRAM์ ๋ง๋ ๋ธ๋ก์ผ๋ก ๋ถํ
- HBM โ SRAM ๋ฐ์ดํฐ ์ ์ก ์ต์ํ
2. ์ฌ๊ณ์ฐ (Recomputation):
- Forward์์ attention weights ์ ์ฅ ์ ํจ
- Backward์์ ํ์ํ ๋ ์ฌ๊ณ์ฐ
- ๋ฉ๋ชจ๋ฆฌ ์ ์ฝ (O(n) vs O(nยฒ))
3. ๊ฒฐ๊ณผ:
- ๋ฉ๋ชจ๋ฆฌ: O(n) vs O(nยฒ)
- ์๋: 2-4x ๋น ๋ฆ
- ์ ํ๋: ์์น์ ์ผ๋ก ๋์ผ
"""
# Usage with PyTorch 2.0+
def use_flash_attention():
import torch.nn.functional as F
    # scaled_dot_product_attention dispatches to Flash Attention when available
Q = torch.randn(2, 8, 1024, 64, device='cuda')
K = torch.randn(2, 8, 1024, 64, device='cuda')
V = torch.randn(2, 8, 1024, 64, device='cuda')
# PyTorch 2.0+ SDPA
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
output = F.scaled_dot_product_attention(Q, K, V)
return output
# Using xFormers
def use_xformers():
from xformers.ops import memory_efficient_attention
Q = torch.randn(2, 1024, 8, 64, device='cuda')
K = torch.randn(2, 1024, 8, 64, device='cuda')
V = torch.randn(2, 1024, 8, 64, device='cuda')
output = memory_efficient_attention(Q, K, V)
return output
## 3. Extending Positional Encodings

### 3.1 The Problem: Inference Beyond the Training Length

Training: 4,096 tokens
Inference: 8,192+ tokens

Problems:
- Absolute position embeddings: positions beyond 4,096 were never trained
- RoPE: degrades under extrapolation; needs interpolation or frequency rescaling
### 3.2 Position Interpolation (PI)
def linear_position_interpolation(
position_ids: torch.Tensor,
original_max_length: int,
extended_max_length: int
) -> torch.Tensor:
"""
    Linear position interpolation.
    Idea: rescale new positions into the original trained range,
    compressing position_ids into [0, original_max_length).
"""
scale = original_max_length / extended_max_length
return position_ids.float() * scale
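A quick usage sketch of this scaling (the logic is repeated inline so the snippet is self-contained): positions trained up to 4,096 are reused for an 8,192-token sequence by compressing them into the trained range.

```python
import torch

# Linear position interpolation by hand: scale extended positions into the
# range the model was trained on. Values are illustrative.
original_max, extended_max = 4096, 8192
scale = original_max / extended_max      # 0.5

position_ids = torch.arange(extended_max)        # 0 .. 8191
scaled = position_ids.float() * scale            # all fall inside [0, 4096)
```

Every scaled position lies inside the trained range, at the cost of halving the resolution between adjacent positions.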
class RoPEWithInterpolation(nn.Module):
"""Position Interpolation์ด ์ ์ฉ๋ RoPE"""
def __init__(
self,
dim: int,
original_max_length: int = 4096,
extended_max_length: int = 16384,
base: float = 10000.0
):
super().__init__()
self.dim = dim
self.original_max_length = original_max_length
self.extended_max_length = extended_max_length
self.base = base
        # compute inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
        # interpolation scale factor
self.scale = original_max_length / extended_max_length
def forward(
self,
x: torch.Tensor,
position_ids: torch.Tensor
) -> torch.Tensor:
"""
Args:
x: (batch, seq_len, heads, dim)
position_ids: (batch, seq_len)
"""
        # interpolate positions into the trained range
scaled_positions = position_ids.float() * self.scale
        # compute per-position rotation angles
freqs = torch.einsum('bi,d->bid', scaled_positions, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
cos = emb.cos().unsqueeze(2) # (batch, seq, 1, dim)
sin = emb.sin().unsqueeze(2)
        # apply RoPE
x_rope = self._apply_rope(x, cos, sin)
return x_rope
def _apply_rope(self, x, cos, sin):
"""RoPE ์ ์ฉ"""
x1 = x[..., : self.dim // 2]
x2 = x[..., self.dim // 2 :]
rotated = torch.cat([-x2, x1], dim=-1)
return x * cos + rotated * sin
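One property worth checking: the rotate-half formula is a pure rotation, so it preserves the norm of every vector even at fractional (interpolated) positions. A self-contained sketch (`apply_rope` mirrors the `_apply_rope` method above):

```python
import torch

# Rotate-half RoPE as a standalone function, to verify norm preservation
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1]
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    rotated = torch.cat([-x2, x1], dim=-1)
    return x * cos + rotated * sin

torch.manual_seed(0)
x = torch.randn(2, 8, 64)                    # (batch, seq, dim)
pos = torch.arange(8).float() * 0.5          # fractional, interpolated positions
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))
freqs = torch.einsum('i,d->id', pos, inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)      # angles duplicated for both halves
out = apply_rope(x, emb.cos(), emb.sin())
```

Each (x1, x2) pair is rotated by the same angle, so vector norms are unchanged; only relative angles between positions carry information.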
### 3.3 YaRN (Yet another RoPE extensioN method)
class YaRNRoPE(nn.Module):
"""
YaRN: NTK-aware Interpolation
Position Interpolation์ ๋ฌธ์ :
- ๊ณ ์ฃผํ ์ ๋ณด ์์ค (๋์ ์ฐจ์)
YaRN ํด๊ฒฐ์ฑ
:
- ์ ์ฃผํ: ๋ณด๊ฐ (interpolation)
- ๊ณ ์ฃผํ: ์ธ์ฝ (extrapolation)
- NTK ์ค์ผ์ผ๋ง์ผ๋ก ์ฃผํ์ ์กฐ์
"""
def __init__(
self,
dim: int,
original_max_length: int = 4096,
extended_max_length: int = 32768,
base: float = 10000.0,
beta_fast: float = 32,
beta_slow: float = 1,
):
super().__init__()
self.dim = dim
self.original_max_length = original_max_length
self.extended_max_length = extended_max_length
scale = extended_max_length / original_max_length
        # Per-dimension interpolation ratio:
        # low dims rotate fast (high frequency)  -> keep original freq (extrapolate)
        # high dims rotate slowly (low frequency) -> rescale (interpolate)
        dims = torch.arange(0, dim, 2)
        low = max(0, math.floor(dim * math.log(scale) / (2 * math.log(original_max_length))))
        high = min(dim // 2 - 1, math.ceil(dim * math.log(scale) / (2 * math.log(original_max_length))))
        # linear ramp between pure extrapolation and pure interpolation
        ramp = torch.zeros(dim // 2)
        ramp[:low] = 0.0   # extrapolate (original frequencies)
        ramp[high:] = 1.0  # interpolate (rescaled frequencies)
        if high > low:
            ramp[low:high] = (torch.arange(low, high).float() - low) / (high - low)
        # original inverse frequencies (used where we extrapolate)
inv_freq = 1.0 / (base ** (dims.float() / dim))
        # interpolated frequencies; blended with the ramp below
inv_freq_inter = inv_freq / scale
self.register_buffer(
'inv_freq',
(1 - ramp) * inv_freq + ramp * inv_freq_inter
)
# Attention scaling
self.mscale = 0.1 * math.log(scale) + 1.0
def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        # compute rotation angles (inv_freq is already blended)
freqs = torch.einsum('bi,d->bid', position_ids.float(), self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
cos = emb.cos().unsqueeze(2) * self.mscale
sin = emb.sin().unsqueeze(2) * self.mscale
        return self._apply_rope(x, cos, sin)

    def _apply_rope(self, x, cos, sin):
        """Apply rotary embedding (rotate-half formulation)."""
        x1 = x[..., : self.dim // 2]
        x2 = x[..., self.dim // 2 :]
        rotated = torch.cat([-x2, x1], dim=-1)
        return x * cos + rotated * sin
## 4. ALiBi (Attention with Linear Biases)

### 4.1 Concept
ALiBi: positional information without position embeddings.

Idea:
- Use no position encoding at all
- Instead, add a distance-based penalty directly to the attention scores
- The farther apart two tokens are, the lower their attention score

Attention score modification:

score(q_i, k_j) = q_i · k_j - m × |i - j|

m: per-head slope (fixed, not learned)
m_h = 2^(-8h/H) for head h = 1..H, where H = number of heads
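The slopes form a geometric sequence, m_h = 2^(-8h/H) for h = 1..H. A quick numeric check for a hypothetical 8-head model:

```python
# Per-head ALiBi slopes for an illustrative H = 8 heads:
# a geometric sequence from 2^-1 = 0.5 down to 2^-8 = 1/256.
H = 8
slopes = [2 ** (-8 * h / H) for h in range(1, H + 1)]
```

Heads with large slopes focus on nearby tokens; heads with tiny slopes remain nearly distance-agnostic, which is what lets ALiBi extrapolate to unseen lengths.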
### 4.2 Implementation
class ALiBiAttention(nn.Module):
"""ALiBi: Attention with Linear Biases"""
def __init__(
self,
hidden_size: int,
num_heads: int,
max_seq_len: int = 8192
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.query = nn.Linear(hidden_size, hidden_size)
self.key = nn.Linear(hidden_size, hidden_size)
self.value = nn.Linear(hidden_size, hidden_size)
self.output = nn.Linear(hidden_size, hidden_size)
        # ALiBi slopes decrease geometrically:
        # 2^(-8/H), 2^(-16/H), ..., 2^(-8)
slopes = self._get_alibi_slopes(num_heads)
self.register_buffer('slopes', slopes)
        # precompute the |i - j| distance matrix
positions = torch.arange(max_seq_len)
distance_matrix = positions.unsqueeze(0) - positions.unsqueeze(1)
distance_matrix = distance_matrix.abs()
self.register_buffer('distance_matrix', distance_matrix)
def _get_alibi_slopes(self, num_heads: int) -> torch.Tensor:
"""Head๋ณ ALiBi slope ๊ณ์ฐ"""
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio ** i for i in range(n)]
if math.log2(num_heads).is_integer():
slopes = get_slopes_power_of_2(num_heads)
else:
            # num_heads is not a power of 2: take extra slopes
            # from the sequence for the next power of 2
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
slopes = get_slopes_power_of_2(closest_power_of_2)
extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)
extra_slopes = extra_slopes[0::2][:num_heads - closest_power_of_2]
slopes = slopes + extra_slopes
return torch.tensor(slopes).view(1, num_heads, 1, 1)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor = None
) -> torch.Tensor:
batch_size, seq_len, _ = hidden_states.shape
        # compute Q, K, V
Q = self.query(hidden_states)
K = self.key(hidden_states)
V = self.value(hidden_states)
# Reshape
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)
# Transpose: (batch, heads, seq, dim)
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# Attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # ALiBi bias: -m × |i - j|
alibi_bias = -self.slopes * self.distance_matrix[:seq_len, :seq_len]
scores = scores + alibi_bias
# Causal mask
causal_mask = torch.triu(
torch.ones(seq_len, seq_len, device=scores.device) * float('-inf'),
diagonal=1
)
scores = scores + causal_mask
# Attention weights
weights = F.softmax(scores, dim=-1)
# Output
output = torch.matmul(weights, V)
output = output.transpose(1, 2).contiguous()
output = output.view(batch_size, seq_len, self.hidden_size)
output = self.output(output)
return output
## 5. Ring Attention

### 5.1 Concept

Ring Attention: distributed long context.

Idea:
- Split the sequence across multiple GPUs
- Each GPU keeps its own chunk of Q while the K/V chunks circulate around a ring
- Communication is overlapped with computation
KV rotation schedule (4 GPUs, total sequence length n):

GPU 0: Q[0:n/4]          GPU 1: Q[n/4:n/2]        ...
  Step 1: K[0:n/4]         Step 1: K[n/4:n/2]
  Step 2: K[n/4:n/2]       Step 2: K[n/2:3n/4]
  Step 3: K[n/2:3n/4]      Step 3: K[3n/4:n]
  Step 4: K[3n/4:n]        Step 4: K[0:n/4]

The K/V blocks circulate around the ring, each combining in turn with the local Q on every GPU.
### 5.2 Implementation Sketch
import torch.distributed as dist
def ring_attention_forward(
    Q: torch.Tensor,  # local Q chunk: (batch, heads, seq, head_dim)
    K: torch.Tensor,  # local K chunk, same layout
    V: torch.Tensor,  # local V chunk, same layout
world_size: int,
rank: int
):
"""
Ring Attention Forward Pass (๊ฐ๋
์ ๊ตฌํ)
์ค์ ๊ตฌํ์ CUDA ์ปค๋๊ณผ ๋ณต์กํ ๋๊ธฐํ ํ์
"""
    local_seq_len = Q.shape[2]
    # accumulators for the streaming (online) softmax
    output = torch.zeros_like(Q)
    max_scores = torch.full(Q.shape[:3], float('-inf'), device=Q.device)
    sum_exp = torch.zeros_like(max_scores)
    # K/V chunk currently held by this rank
current_K = K.clone()
current_V = V.clone()
for step in range(world_size):
        # attention over the KV chunk currently on this rank
scores = torch.matmul(Q, current_K.transpose(-2, -1))
scores = scores / math.sqrt(Q.shape[-1])
# Online softmax (numerically stable)
new_max = torch.max(scores.max(dim=-1).values, max_scores)
exp_scores = torch.exp(scores - new_max.unsqueeze(-1))
        # rescale previously accumulated results
scale = torch.exp(max_scores - new_max)
output = output * scale.unsqueeze(-1) + torch.matmul(exp_scores, current_V)
sum_exp = sum_exp * scale + exp_scores.sum(dim=-1)
max_scores = new_max
        # pass KV to the next GPU in the ring
if step < world_size - 1:
            # asynchronous send/recv
send_rank = (rank + 1) % world_size
recv_rank = (rank - 1) % world_size
            # receive the KV chunk from the previous GPU
current_K = ring_pass(current_K, send_rank, recv_rank)
current_V = ring_pass(current_V, send_rank, recv_rank)
    # final normalization
output = output / sum_exp.unsqueeze(-1)
return output
def ring_pass(tensor, send_rank, recv_rank):
"""Ring topology์์ ํ
์ ์ ๋ฌ"""
recv_tensor = torch.empty_like(tensor)
send_op = dist.isend(tensor, send_rank)
recv_op = dist.irecv(recv_tensor, recv_rank)
send_op.wait()
recv_op.wait()
return recv_tensor
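The online-softmax accumulation at the heart of the loop above can be verified on a single device: processing K/V in chunks with a running max and running normalizer reproduces full softmax attention exactly. A minimal sketch (shapes simplified to `(seq, dim)`, no heads or batch):

```python
import math
import torch

# Online-softmax attention over K/V chunks, as used by Ring Attention
# (and Flash Attention), checked against a single full-softmax pass.
def chunked_attention(Q, K, V, chunk=16):
    d = Q.shape[-1]
    out = torch.zeros_like(Q)
    running_max = torch.full((Q.shape[0],), float('-inf'))
    running_sum = torch.zeros(Q.shape[0])
    for s in range(0, K.shape[0], chunk):
        k, v = K[s:s + chunk], V[s:s + chunk]
        scores = Q @ k.T / math.sqrt(d)                        # (seq, chunk)
        new_max = torch.maximum(running_max, scores.max(dim=-1).values)
        exp_scores = torch.exp(scores - new_max.unsqueeze(-1))
        scale = torch.exp(running_max - new_max)               # rescale old sums
        out = out * scale.unsqueeze(-1) + exp_scores @ v
        running_sum = running_sum * scale + exp_scores.sum(dim=-1)
        running_max = new_max
    return out / running_sum.unsqueeze(-1)

torch.manual_seed(0)
Q, K, V = (torch.randn(64, 32) for _ in range(3))
reference = torch.softmax(Q @ K.T / math.sqrt(32), dim=-1) @ V
```

The chunked result matches the reference to floating-point precision, which is why the ring can stream KV blocks without ever materializing the full score matrix.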
## 6. Practical Guide

### 6.1 Choosing a Context-Extension Method
Which method to use, by target length:

- **4K → 8K**: Position Interpolation (simple, works well); a small amount of fine-tuning is recommended
- **4K → 32K**: YaRN (better quality than PI), or ALiBi when training from scratch
- **32K → 100K+**: Flash Attention is essential; Ring Attention for multi-GPU; consider sparse attention
- **1M+**: specialized architectures are required: Mamba / state-space models, or extremely sparse attention
### 6.2 Practical Tips

# 1. Gradient checkpointing is essential
model.gradient_checkpointing_enable()

# 2. Use mixed precision
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    outputs = model(**inputs)

# 3. KV-cache optimization (at inference time):
#    - sliding-window cache
#    - paged attention (vLLM)

# 4. Chunked processing for long documents
def process_long_document(model, document, chunk_size=4096, overlap=512):
    """Process a long document in overlapping chunks.

    `tokenizer` and `merge_results` are assumed to be defined elsewhere.
    """
tokens = tokenizer.encode(document)
results = []
for i in range(0, len(tokens), chunk_size - overlap):
chunk = tokens[i:i + chunk_size]
output = model.generate(chunk)
results.append(output)
return merge_results(results)
## References

### Papers
- Beltagy et al. (2020). "Longformer: The Long-Document Transformer"
- Dao et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention"
- Press et al. (2021). "Train Short, Test Long: Attention with Linear Biases"
- Peng et al. (2023). "YaRN: Efficient Context Window Extension of Large Language Models"
### Related Lessons
- 08_LLaMA_Family.md - RoPE basics
- 09_Mistral_MoE.md - Sliding Window Attention