18. Attention 메커니즘 심화

18. Attention 메커니즘 심화

학습 목표

  • Multi-Head Attention 수학적 세부 이해
  • Attention 복잡도 분석 (O(n^2))
  • Flash Attention 원리
  • Sparse Attention 기법들
  • 위치 인코딩 심화 (RoPE, ALiBi)
  • Attention 시각화 기법
  • PyTorch 효율적 구현

1. Multi-Head Attention 수학

수식 복습

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

Q = XW_Q  (query)
K = XW_K  (key)
V = XW_V  (value)

X: 입력 (batch, seq_len, d_model)
W_Q, W_K, W_V: 학습 가능한 가중치

상세 차원 분석

# 입력
# X: (batch, seq_len, d_model)
# 예: (32, 512, 768)

# 가중치 행렬
# W_Q, W_K, W_V: (d_model, d_model)
# 예: (768, 768)

# Q, K, V 계산
# Q = X @ W_Q: (32, 512, 768)
# K = X @ W_K: (32, 512, 768)
# V = X @ W_V: (32, 512, 768)

# Multi-Head 분할
# num_heads = 12, head_dim = 768 / 12 = 64
# Q: (32, 512, 768) → (32, 12, 512, 64)
# K: (32, 512, 768) → (32, 12, 512, 64)
# V: (32, 512, 768) → (32, 12, 512, 64)

# Attention Score
# scores = Q @ K^T: (32, 12, 512, 64) @ (32, 12, 64, 512) = (32, 12, 512, 512)

# Scaled
# scores = scores / sqrt(64) = scores / 8

# Softmax
# weights = softmax(scores): (32, 12, 512, 512)

# Attention Output
# output = weights @ V: (32, 12, 512, 512) @ (32, 12, 512, 64) = (32, 12, 512, 64)

# Concat
# output: (32, 12, 512, 64) → (32, 512, 768)

# Output Projection
# output = output @ W_O: (32, 512, 768)

PyTorch 구현

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """Multi-Head Attention 상세 구현 (⭐⭐⭐)"""
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = math.sqrt(self.head_dim)

        # 가중치 행렬
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None, return_attention=False):
        batch_size, seq_len, _ = query.size()

        # 1. Linear projections
        Q = self.W_q(query)  # (batch, seq, d_model)
        K = self.W_k(key)
        V = self.W_v(value)

        # 2. Split into heads
        # (batch, seq, d_model) -> (batch, num_heads, seq, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # 3. Scaled dot-product attention
        # (batch, heads, seq_q, head_dim) @ (batch, heads, head_dim, seq_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # 4. Apply mask (optional)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # 5. Softmax
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 6. Apply attention to values
        # (batch, heads, seq_q, seq_k) @ (batch, heads, seq_k, head_dim)
        attention_output = torch.matmul(attention_weights, V)

        # 7. Concatenate heads
        # (batch, heads, seq, head_dim) -> (batch, seq, d_model)
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, seq_len, self.d_model)

        # 8. Output projection
        output = self.W_o(attention_output)

        if return_attention:
            return output, attention_weights
        return output

2. Attention 복잡도 분석

시간 복잡도

Q @ K^T: O(n * n * d) = O(n^2 * d)
    Q: (n, d), K^T: (d, n) → 결과: (n, n)

softmax: O(n^2)

weights @ V: O(n^2 * d)
    weights: (n, n), V: (n, d)

Total: O(n^2 * d)

공간 복잡도

Attention matrix: O(n^2)
    scores: (n, n) 저장 필요

긴 시퀀스 문제:
    n = 10,000, heads = 12
    메모리 = 12 * 10000 * 10000 * 4 bytes = 4.8 GB (float32)

병목점

# 메모리 병목 시각화

n = [1000, 2000, 4000, 8000, 16000]
memory_gb = [h * n_i * n_i * 4 / (1024**3) for n_i in n for h in [12]]

# n=1000:  ~0.048 GB
# n=4000:  ~0.77 GB
# n=8000:  ~3.1 GB
# n=16000: ~12.3 GB

# GPU 메모리 한계로 긴 시퀀스 처리 불가

3. Flash Attention

핵심 아이디어

문제: attention matrix (n x n)  번에 메모리에 올림  메모리 폭발

해결: 블록 단위로 계산, SRAM(빠른 on-chip 메모리) 활용

1. Q, K, V를 블록으로 분할
2.  블록에서 attention 계산 (SRAM에서)
3. Online softmax로 블록 결과 병합

Online Softmax

def online_softmax(scores_blocks):
    """Online Softmax 알고리즘 (⭐⭐⭐⭐)

    블록별로 softmax를 계산하고 병합
    전체 점수를 메모리에 올리지 않음
    """
    # Block 1: [s1, s2, s3]
    # Block 2: [s4, s5, s6]

    # 일반 softmax:
    # exp(s1 - max_all) / sum_all_exp

    # Online:
    # 블록마다 local_max, local_sum 유지
    # 새 블록 볼 때 global_max, global_sum 업데이트

    max_so_far = float('-inf')
    sum_so_far = 0

    for block in scores_blocks:
        block_max = block.max()
        new_max = max(max_so_far, block_max)

        # 기존 sum 조정
        sum_so_far = sum_so_far * math.exp(max_so_far - new_max)
        # 새 블록 추가
        sum_so_far += (block - new_max).exp().sum()

        max_so_far = new_max

    return max_so_far, sum_so_far

PyTorch 2.0+ Flash Attention

# PyTorch 2.0부터 내장 지원

def flash_attention_example():
    """Flash Attention 사용 예시 (⭐⭐⭐)"""
    batch_size = 32
    seq_len = 4096
    num_heads = 12
    head_dim = 64

    Q = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')
    K = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')
    V = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')

    # PyTorch 2.0+ scaled_dot_product_attention
    # Flash Attention 자동 적용 (조건 충족 시)
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,
        enable_mem_efficient=False
    ):
        output = F.scaled_dot_product_attention(Q, K, V)

    return output

# 메모리 비교
# 일반: O(n^2)
# Flash: O(n) - attention matrix를 저장하지 않음

성능 비교

def benchmark_attention(seq_lens, batch_size=8, num_heads=12, head_dim=64):
    """Attention 성능 벤치마크 (⭐⭐)"""
    import time

    results = []

    for seq_len in seq_lens:
        Q = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')
        K = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')
        V = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')

        # Warmup
        _ = F.scaled_dot_product_attention(Q, K, V)
        torch.cuda.synchronize()

        # Timing
        start = time.time()
        for _ in range(100):
            _ = F.scaled_dot_product_attention(Q, K, V)
        torch.cuda.synchronize()
        elapsed = (time.time() - start) / 100

        # Memory
        torch.cuda.reset_peak_memory_stats()
        _ = F.scaled_dot_product_attention(Q, K, V)
        memory = torch.cuda.max_memory_allocated() / 1e9

        results.append({
            'seq_len': seq_len,
            'time_ms': elapsed * 1000,
            'memory_gb': memory
        })

    return results

4. Sparse Attention

Local Attention

class LocalAttention(nn.Module):
    """Local (Sliding Window) Attention (⭐⭐⭐)

    각 토큰이 주변 window_size 토큰만 참조
    복잡도: O(n * window_size)
    """
    def __init__(self, d_model, num_heads, window_size=256):
        super().__init__()
        self.window_size = window_size
        self.attention = MultiHeadAttention(d_model, num_heads)

    def forward(self, x):
        batch, seq_len, d_model = x.shape

        # 마스크 생성: window 밖은 -inf
        mask = torch.zeros(seq_len, seq_len, device=x.device)
        for i in range(seq_len):
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2 + 1)
            mask[i, start:end] = 1

        return self.attention(x, x, x, mask=mask)

Strided Attention

class StridedAttention(nn.Module):
    """Strided Attention (⭐⭐⭐)

    일정 간격으로 떨어진 토큰만 참조
    예: stride=4면 0, 4, 8, 12, ... 만 참조
    """
    def __init__(self, d_model, num_heads, stride=4):
        super().__init__()
        self.stride = stride
        self.attention = MultiHeadAttention(d_model, num_heads)

    def forward(self, x):
        batch, seq_len, _ = x.shape

        # 마스크: stride 간격 토큰만 attend
        mask = torch.zeros(seq_len, seq_len, device=x.device)
        for i in range(seq_len):
            indices = list(range(0, seq_len, self.stride))
            mask[i, indices] = 1

        return self.attention(x, x, x, mask=mask)

BigBird Attention

class BigBirdAttention(nn.Module):
    """BigBird Sparse Attention (⭐⭐⭐⭐)

    3가지 패턴 조합:
    1. Local: 주변 토큰
    2. Global: 특정 토큰(CLS 등)이 모든 토큰 참조
    3. Random: 랜덤 토큰 참조
    """
    def __init__(self, d_model, num_heads, window_size=64, num_global=2, num_random=3):
        super().__init__()
        self.window_size = window_size
        self.num_global = num_global
        self.num_random = num_random
        self.attention = MultiHeadAttention(d_model, num_heads)

    def create_bigbird_mask(self, seq_len, device):
        mask = torch.zeros(seq_len, seq_len, device=device)

        # 1. Local attention
        for i in range(seq_len):
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2 + 1)
            mask[i, start:end] = 1

        # 2. Global tokens (처음 num_global 토큰)
        mask[:self.num_global, :] = 1  # global → all
        mask[:, :self.num_global] = 1  # all → global

        # 3. Random connections
        for i in range(seq_len):
            random_indices = torch.randperm(seq_len)[:self.num_random]
            mask[i, random_indices] = 1

        return mask

    def forward(self, x):
        batch, seq_len, _ = x.shape
        mask = self.create_bigbird_mask(seq_len, x.device)
        return self.attention(x, x, x, mask=mask.unsqueeze(0).unsqueeze(0))

Longformer Attention

class LongformerAttention(nn.Module):
    """Longformer Attention (⭐⭐⭐⭐)

    Local + Global tokens
    Global: 특정 토큰 (CLS, SEP 등)만 전체 참조
    """
    def __init__(self, d_model, num_heads, window_size=256, global_indices=None):
        super().__init__()
        self.window_size = window_size
        self.global_indices = global_indices or [0]  # 기본: CLS만
        self.attention = MultiHeadAttention(d_model, num_heads)

    def forward(self, x):
        batch, seq_len, _ = x.shape

        mask = torch.zeros(seq_len, seq_len, device=x.device)

        # Local
        for i in range(seq_len):
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2 + 1)
            mask[i, start:end] = 1

        # Global
        for idx in self.global_indices:
            mask[idx, :] = 1
            mask[:, idx] = 1

        return self.attention(x, x, x, mask=mask)

5. 위치 인코딩 심화

Sinusoidal (원본 Transformer)

class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal 위치 인코딩 (⭐⭐)"""
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                            (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

Learned Position Embeddings

class LearnedPositionalEncoding(nn.Module):
    """학습 가능한 위치 인코딩 (⭐⭐)"""
    def __init__(self, d_model, max_len=512):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        return x + self.pos_embedding(positions)

RoPE (Rotary Position Embedding)

class RotaryPositionalEmbedding(nn.Module):
    """RoPE - LLaMA, GPT-NeoX 등에서 사용 (⭐⭐⭐⭐)

    위치 정보를 회전 행렬로 인코딩
    상대 위치 정보를 자연스럽게 표현
    """
    def __init__(self, dim, max_len=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_len = max_len

        # 캐싱
        self._set_cos_sin_cache(max_len)

    def _set_cos_sin_cache(self, seq_len):
        t = torch.arange(seq_len, device=self.inv_freq.device).float()
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def forward(self, q, k, seq_len):
        # q, k: (batch, heads, seq, head_dim)
        cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)

        q_rotated = self._apply_rotary(q, cos, sin)
        k_rotated = self._apply_rotary(k, cos, sin)

        return q_rotated, k_rotated

    def _apply_rotary(self, x, cos, sin):
        # x: (batch, heads, seq, head_dim)
        # x를 반으로 나눠서 회전
        x1, x2 = x[..., ::2], x[..., 1::2]
        # 회전 적용
        rotated = torch.stack(
            [-x2, x1], dim=-1
        ).flatten(-2)

        return x * cos + rotated * sin


def apply_rope(q, k, rope_module, seq_len):
    """RoPE 적용 예시"""
    return rope_module(q, k, seq_len)

ALiBi (Attention with Linear Biases)

class ALiBiPositionalBias(nn.Module):
    """ALiBi - MPT, BLOOM 등에서 사용 (⭐⭐⭐⭐)

    위치 임베딩 대신 attention score에 선형 bias 추가
    학습 파라미터 없음
    길이 외삽 능력 우수
    """
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
        # 각 head마다 다른 기울기
        slopes = self._get_slopes(num_heads)
        self.register_buffer('slopes', slopes)

    def _get_slopes(self, n):
        """기울기 계산: 2^(-8/n), 2^(-16/n), ..."""
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(n).is_integer():
            return torch.tensor(get_slopes_power_of_2(n))
        else:
            closest_power = 2 ** math.floor(math.log2(n))
            return torch.tensor(
                get_slopes_power_of_2(closest_power) +
                get_slopes_power_of_2(2 * closest_power)[0::2][:n - closest_power]
            )

    def forward(self, seq_len, device):
        # 상대 거리 행렬
        positions = torch.arange(seq_len, device=device)
        relative_positions = positions.unsqueeze(0) - positions.unsqueeze(1)
        relative_positions = relative_positions.abs().unsqueeze(0)

        # bias = -slope * distance
        alibi = -self.slopes.unsqueeze(1).unsqueeze(1) * relative_positions

        return alibi  # (num_heads, seq_len, seq_len)


def attention_with_alibi(Q, K, V, alibi_bias):
    """ALiBi 적용 attention (⭐⭐⭐)"""
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # ALiBi bias 추가
    scores = scores + alibi_bias

    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V)

6. Attention 시각화

Attention Weights 시각화

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, tokens=None, layer=0, head=0):
    """단일 head attention 시각화 (⭐⭐)"""
    # attention_weights: (batch, heads, seq, seq)
    weights = attention_weights[0, head].cpu().detach().numpy()

    plt.figure(figsize=(10, 8))
    sns.heatmap(weights, cmap='Blues', square=True)

    if tokens:
        plt.xticks(range(len(tokens)), tokens, rotation=45, ha='right')
        plt.yticks(range(len(tokens)), tokens, rotation=0)

    plt.title(f'Attention Weights (Layer {layer}, Head {head})')
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.tight_layout()
    plt.savefig(f'attention_layer{layer}_head{head}.png')
    plt.close()


def visualize_all_heads(attention_weights, tokens=None, ncols=4):
    """모든 head 시각화 (⭐⭐)"""
    num_heads = attention_weights.size(1)
    nrows = (num_heads + ncols - 1) // ncols

    fig, axes = plt.subplots(nrows, ncols, figsize=(4*ncols, 4*nrows))
    axes = axes.flatten()

    for head in range(num_heads):
        weights = attention_weights[0, head].cpu().detach().numpy()
        ax = axes[head]
        im = ax.imshow(weights, cmap='Blues')
        ax.set_title(f'Head {head}')
        ax.axis('off')

    for i in range(num_heads, len(axes)):
        axes[i].axis('off')

    plt.tight_layout()
    plt.savefig('attention_all_heads.png')
    plt.close()

Attention Flow (BertViz 스타일)

def plot_attention_flow(attention_weights, tokens, top_k=3):
    """상위 K개 attention 연결 시각화 (⭐⭐⭐)"""
    weights = attention_weights[0, 0].cpu().detach().numpy()
    seq_len = len(tokens)

    plt.figure(figsize=(12, 8))

    # 토큰 위치
    left_positions = [(0, i) for i in range(seq_len)]
    right_positions = [(1, i) for i in range(seq_len)]

    # 토큰 표시
    for i, token in enumerate(tokens):
        plt.text(-0.1, i, token, ha='right', va='center', fontsize=10)
        plt.text(1.1, i, token, ha='left', va='center', fontsize=10)

    # 상위 K개 연결
    for i in range(seq_len):
        top_indices = weights[i].argsort()[-top_k:]
        for j in top_indices:
            alpha = weights[i, j] / weights[i].max()
            plt.plot([0, 1], [i, j], 'b-', alpha=alpha, linewidth=2)

    plt.xlim(-0.5, 1.5)
    plt.ylim(-1, seq_len)
    plt.axis('off')
    plt.title('Attention Flow')
    plt.savefig('attention_flow.png')
    plt.close()

7. 효율적 구현

PyTorch scaled_dot_product_attention

def efficient_attention_comparison():
    """효율적 attention 구현 비교 (⭐⭐⭐)"""
    import time

    batch, heads, seq, dim = 8, 12, 2048, 64

    Q = torch.randn(batch, heads, seq, dim, device='cuda')
    K = torch.randn(batch, heads, seq, dim, device='cuda')
    V = torch.randn(batch, heads, seq, dim, device='cuda')

    # 1. Naive implementation
    def naive_attention(Q, K, V):
        scale = math.sqrt(Q.size(-1))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, V)

    # 2. PyTorch built-in (may use Flash Attention)
    def pytorch_attention(Q, K, V):
        return F.scaled_dot_product_attention(Q, K, V)

    # Warmup
    _ = naive_attention(Q, K, V)
    _ = pytorch_attention(Q, K, V)
    torch.cuda.synchronize()

    # Benchmark naive
    torch.cuda.reset_peak_memory_stats()
    start = time.time()
    for _ in range(10):
        _ = naive_attention(Q, K, V)
    torch.cuda.synchronize()
    naive_time = (time.time() - start) / 10
    naive_mem = torch.cuda.max_memory_allocated() / 1e9

    # Benchmark pytorch
    torch.cuda.reset_peak_memory_stats()
    start = time.time()
    for _ in range(10):
        _ = pytorch_attention(Q, K, V)
    torch.cuda.synchronize()
    pytorch_time = (time.time() - start) / 10
    pytorch_mem = torch.cuda.max_memory_allocated() / 1e9

    print(f"Naive:   {naive_time*1000:.2f}ms, {naive_mem:.2f}GB")
    print(f"PyTorch: {pytorch_time*1000:.2f}ms, {pytorch_mem:.2f}GB")

    return naive_time, pytorch_time

Memory-Efficient Attention

class MemoryEfficientAttention(nn.Module):
    """메모리 효율적 Attention (청크 기반) (⭐⭐⭐⭐)"""
    def __init__(self, d_model, num_heads, chunk_size=256):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.chunk_size = chunk_size
        self.scale = math.sqrt(self.head_dim)

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, N, _ = x.shape

        Q = self.W_q(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        # 청크별 처리로 메모리 절약
        outputs = []
        for i in range(0, N, self.chunk_size):
            end = min(i + self.chunk_size, N)
            Q_chunk = Q[:, :, i:end]

            # 이 청크에 대해 전체 K, V와 attention
            scores = torch.matmul(Q_chunk, K.transpose(-2, -1)) / self.scale
            weights = F.softmax(scores, dim=-1)
            chunk_out = torch.matmul(weights, V)
            outputs.append(chunk_out)

        output = torch.cat(outputs, dim=2)
        output = output.transpose(1, 2).contiguous().view(B, N, self.d_model)

        return self.W_o(output)

8. Attention 패턴 분석

자주 관찰되는 패턴

def analyze_attention_patterns(attention_weights):
    """Attention 패턴 분석 (⭐⭐⭐)"""
    # attention_weights: (batch, heads, seq, seq)

    patterns = {}

    for head in range(attention_weights.size(1)):
        weights = attention_weights[0, head].cpu().detach()

        # 1. Diagonal (자기 자신 주목)
        diagonal = torch.diag(weights).mean()

        # 2. Previous token (이전 토큰)
        prev_token = torch.diag(weights, -1).mean() if weights.size(0) > 1 else 0

        # 3. First token (CLS 등)
        first_token = weights[:, 0].mean()

        # 4. Uniform (균일 분포)
        uniform_entropy = -(weights * weights.log()).sum(dim=-1).mean()

        patterns[f'head_{head}'] = {
            'diagonal': diagonal.item(),
            'prev_token': prev_token.item() if isinstance(prev_token, torch.Tensor) else prev_token,
            'first_token': first_token.item(),
            'entropy': uniform_entropy.item()
        }

    return patterns

정리

핵심 개념

  1. MHA 수학: Q, K, V 투영 → 스케일링 → softmax → 가중합
  2. 복잡도: O(n^2) 시간/공간 - 긴 시퀀스 병목
  3. Flash Attention: 블록 처리 + Online softmax → O(n) 메모리
  4. Sparse Attention: Local, Strided, BigBird 등
  5. 위치 인코딩: Sinusoidal, Learned, RoPE, ALiBi

효율화 요약

기법 시간 공간 특징
Standard O(n^2) O(n^2) 기본
Flash O(n^2) O(n) HW 최적화
Local O(nw) O(nw) window w
Sparse O(n sqrt(n)) O(n) 패턴 조합

PyTorch 실전 팁

# 1. PyTorch 2.0+ 사용
output = F.scaled_dot_product_attention(Q, K, V)

# 2. Flash Attention 강제 활성화
with torch.backends.cuda.sdp_kernel(enable_flash=True):
    output = F.scaled_dot_product_attention(Q, K, V)

# 3. 긴 시퀀스는 청킹 고려

# 4. RoPE/ALiBi는 길이 외삽에 유리

참고 자료

  • Flash Attention: https://arxiv.org/abs/2205.14135
  • RoPE: https://arxiv.org/abs/2104.09864
  • ALiBi: https://arxiv.org/abs/2108.12409
  • BigBird: https://arxiv.org/abs/2007.14062
to navigate between lessons