17. Attention Mechanism Deep Dive

17. Attention Mechanism Deep Dive

Previous: Attention & Transformer | Next: Transformer Implementation


Learning Objectives

  • In-depth understanding of Multi-Head Attention mathematics
  • Attention complexity analysis (O(n^2))
  • Flash Attention principles
  • Sparse Attention techniques
  • Advanced positional encoding (RoPE, ALiBi)
  • Attention visualization techniques
  • Efficient PyTorch implementation

1. Multi-Head Attention Mathematics

Formula Review

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

Q = XW_Q  (query)
K = XW_K  (key)
V = XW_V  (value)

X: input (batch, seq_len, d_model)
W_Q, W_K, W_V: learnable weights

Detailed Dimension Analysis

# Input
# X: (batch, seq_len, d_model)
# Example: (32, 512, 768)

# Weight matrices
# W_Q, W_K, W_V: (d_model, d_model)
# Example: (768, 768)

# Q, K, V computation
# Q = X @ W_Q: (32, 512, 768)
# K = X @ W_K: (32, 512, 768)
# V = X @ W_V: (32, 512, 768)

# Multi-Head split
# num_heads = 12, head_dim = 768 / 12 = 64
# Q: (32, 512, 768) → (32, 12, 512, 64)
# K: (32, 512, 768) → (32, 12, 512, 64)
# V: (32, 512, 768) → (32, 12, 512, 64)

# Attention Score
# scores = Q @ K^T: (32, 12, 512, 64) @ (32, 12, 64, 512) = (32, 12, 512, 512)

# Scaled
# scores = scores / sqrt(64) = scores / 8

# Softmax
# weights = softmax(scores): (32, 12, 512, 512)

# Attention Output
# output = weights @ V: (32, 12, 512, 512) @ (32, 12, 512, 64) = (32, 12, 512, 64)

# Concat
# output: (32, 12, 512, 64) → (32, 512, 768)

# Output Projection
# output = output @ W_O: (32, 512, 768)

PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """Detailed Multi-Head Attention Implementation (⭐⭐⭐)"""
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = math.sqrt(self.head_dim)

        # Weight matrices
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None, return_attention=False):
        batch_size, seq_len, _ = query.size()

        # 1. Linear projections
        Q = self.W_q(query)  # (batch, seq, d_model)
        K = self.W_k(key)
        V = self.W_v(value)

        # 2. Split into heads
        # (batch, seq, d_model) -> (batch, num_heads, seq, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # 3. Scaled dot-product attention
        # (batch, heads, seq_q, head_dim) @ (batch, heads, head_dim, seq_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # 4. Apply mask (optional)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # 5. Softmax
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 6. Apply attention to values
        # (batch, heads, seq_q, seq_k) @ (batch, heads, seq_k, head_dim)
        attention_output = torch.matmul(attention_weights, V)

        # 7. Concatenate heads
        # (batch, heads, seq, head_dim) -> (batch, seq, d_model)
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, seq_len, self.d_model)

        # 8. Output projection
        output = self.W_o(attention_output)

        if return_attention:
            return output, attention_weights
        return output

2. Attention Complexity Analysis

Time Complexity

Q @ K^T: O(n * n * d) = O(n^2 * d)
    Q: (n, d), K^T: (d, n) → result: (n, n)

softmax: O(n^2)

weights @ V: O(n^2 * d)
    weights: (n, n), V: (n, d)

Total: O(n^2 * d)

Space Complexity

Attention matrix: O(n^2)
    scores: (n, n) storage required

Long sequence problem:
    n = 10,000, heads = 12
    memory = 12 * 10000 * 10000 * 4 bytes = 4.8 GB (float32)

Bottleneck

# Memory bottleneck visualization

n = [1000, 2000, 4000, 8000, 16000]
memory_gb = [h * n_i * n_i * 4 / (1024**3) for n_i in n for h in [12]]

# n=1000:  ~0.048 GB
# n=4000:  ~0.77 GB
# n=8000:  ~3.1 GB
# n=16000: ~12.3 GB

# Cannot process long sequences due to GPU memory limits

3. Flash Attention

Key Idea

Problem: loading entire attention matrix (n x n) into memory  memory explosion

Solution: compute in blocks, utilize SRAM (fast on-chip memory)

1. Split Q, K, V into blocks
2. Compute attention per block (in SRAM)
3. Merge block results with online softmax

Online Softmax

def online_softmax(scores_blocks):
    """Online Softmax Algorithm (⭐⭐⭐⭐)

    Compute and merge softmax per block
    Avoid loading all scores into memory
    """
    # Block 1: [s1, s2, s3]
    # Block 2: [s4, s5, s6]

    # Standard softmax:
    # exp(s1 - max_all) / sum_all_exp

    # Online:
    # Maintain local_max, local_sum per block
    # Update global_max, global_sum when seeing new blocks

    max_so_far = float('-inf')
    sum_so_far = 0

    for block in scores_blocks:
        block_max = block.max()
        new_max = max(max_so_far, block_max)

        # Adjust existing sum
        sum_so_far = sum_so_far * math.exp(max_so_far - new_max)
        # Add new block
        sum_so_far += (block - new_max).exp().sum()

        max_so_far = new_max

    return max_so_far, sum_so_far

PyTorch 2.0+ Flash Attention

# Built-in support from PyTorch 2.0

def flash_attention_example():
    """Flash Attention Usage Example (⭐⭐⭐)"""
    batch_size = 32
    seq_len = 4096
    num_heads = 12
    head_dim = 64

    Q = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')
    K = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')
    V = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')

    # PyTorch 2.0+ scaled_dot_product_attention
    # Flash Attention automatically applied (when conditions met)
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,
        enable_mem_efficient=False
    ):
        output = F.scaled_dot_product_attention(Q, K, V)

    return output

# Memory comparison
# Standard: O(n^2)
# Flash: O(n) - doesn't store attention matrix

Performance Comparison

def benchmark_attention(seq_lens, batch_size=8, num_heads=12, head_dim=64):
    """Attention Performance Benchmark (⭐⭐)"""
    import time

    results = []

    for seq_len in seq_lens:
        Q = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')
        K = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')
        V = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')

        # Warmup
        _ = F.scaled_dot_product_attention(Q, K, V)
        torch.cuda.synchronize()

        # Timing
        start = time.time()
        for _ in range(100):
            _ = F.scaled_dot_product_attention(Q, K, V)
        torch.cuda.synchronize()
        elapsed = (time.time() - start) / 100

        # Memory
        torch.cuda.reset_peak_memory_stats()
        _ = F.scaled_dot_product_attention(Q, K, V)
        memory = torch.cuda.max_memory_allocated() / 1e9

        results.append({
            'seq_len': seq_len,
            'time_ms': elapsed * 1000,
            'memory_gb': memory
        })

    return results

4. Sparse Attention

Local Attention

class LocalAttention(nn.Module):
    """Local (Sliding Window) Attention (⭐⭐⭐)

    Each token attends only to nearby window_size tokens
    Complexity: O(n * window_size)
    """
    def __init__(self, d_model, num_heads, window_size=256):
        super().__init__()
        self.window_size = window_size
        self.attention = MultiHeadAttention(d_model, num_heads)

    def forward(self, x):
        batch, seq_len, d_model = x.shape

        # Create mask: outside window is -inf
        mask = torch.zeros(seq_len, seq_len, device=x.device)
        for i in range(seq_len):
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2 + 1)
            mask[i, start:end] = 1

        return self.attention(x, x, x, mask=mask)

Strided Attention

class StridedAttention(nn.Module):
    """Strided Attention (⭐⭐⭐)

    Attend to tokens at fixed intervals
    Example: stride=4 attends to 0, 4, 8, 12, ...
    """
    def __init__(self, d_model, num_heads, stride=4):
        super().__init__()
        self.stride = stride
        self.attention = MultiHeadAttention(d_model, num_heads)

    def forward(self, x):
        batch, seq_len, _ = x.shape

        # Mask: only attend to stride-spaced tokens
        mask = torch.zeros(seq_len, seq_len, device=x.device)
        for i in range(seq_len):
            indices = list(range(0, seq_len, self.stride))
            mask[i, indices] = 1

        return self.attention(x, x, x, mask=mask)

BigBird Attention

class BigBirdAttention(nn.Module):
    """BigBird Sparse Attention (⭐⭐⭐⭐)

    Combines 3 patterns:
    1. Local: nearby tokens
    2. Global: special tokens (CLS etc) attend to all
    3. Random: random token connections
    """
    def __init__(self, d_model, num_heads, window_size=64, num_global=2, num_random=3):
        super().__init__()
        self.window_size = window_size
        self.num_global = num_global
        self.num_random = num_random
        self.attention = MultiHeadAttention(d_model, num_heads)

    def create_bigbird_mask(self, seq_len, device):
        mask = torch.zeros(seq_len, seq_len, device=device)

        # 1. Local attention
        for i in range(seq_len):
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2 + 1)
            mask[i, start:end] = 1

        # 2. Global tokens (first num_global tokens)
        mask[:self.num_global, :] = 1  # global → all
        mask[:, :self.num_global] = 1  # all → global

        # 3. Random connections
        for i in range(seq_len):
            random_indices = torch.randperm(seq_len)[:self.num_random]
            mask[i, random_indices] = 1

        return mask

    def forward(self, x):
        batch, seq_len, _ = x.shape
        mask = self.create_bigbird_mask(seq_len, x.device)
        return self.attention(x, x, x, mask=mask.unsqueeze(0).unsqueeze(0))

Longformer Attention

class LongformerAttention(nn.Module):
    """Longformer Attention (⭐⭐⭐⭐)

    Local + Global tokens
    Global: only specific tokens (CLS, SEP etc) attend globally
    """
    def __init__(self, d_model, num_heads, window_size=256, global_indices=None):
        super().__init__()
        self.window_size = window_size
        self.global_indices = global_indices or [0]  # default: CLS only
        self.attention = MultiHeadAttention(d_model, num_heads)

    def forward(self, x):
        batch, seq_len, _ = x.shape

        mask = torch.zeros(seq_len, seq_len, device=x.device)

        # Local
        for i in range(seq_len):
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2 + 1)
            mask[i, start:end] = 1

        # Global
        for idx in self.global_indices:
            mask[idx, :] = 1
            mask[:, idx] = 1

        return self.attention(x, x, x, mask=mask)

5. Advanced Positional Encoding

Sinusoidal (Original Transformer)

class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal Positional Encoding (⭐⭐)"""
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                            (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

Learned Position Embeddings

class LearnedPositionalEncoding(nn.Module):
    """Learnable Positional Encoding (⭐⭐)"""
    def __init__(self, d_model, max_len=512):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        return x + self.pos_embedding(positions)

RoPE (Rotary Position Embedding)

class RotaryPositionalEmbedding(nn.Module):
    """RoPE - Used in LLaMA, GPT-NeoX etc (⭐⭐⭐⭐)

    Encode position information via rotation matrices
    Naturally represents relative position information
    """
    def __init__(self, dim, max_len=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_len = max_len

        # Caching
        self._set_cos_sin_cache(max_len)

    def _set_cos_sin_cache(self, seq_len):
        t = torch.arange(seq_len, device=self.inv_freq.device).float()
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def forward(self, q, k, seq_len):
        # q, k: (batch, heads, seq, head_dim)
        cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)

        q_rotated = self._apply_rotary(q, cos, sin)
        k_rotated = self._apply_rotary(k, cos, sin)

        return q_rotated, k_rotated

    def _apply_rotary(self, x, cos, sin):
        # x: (batch, heads, seq, head_dim)
        # Split x in half and rotate
        x1, x2 = x[..., ::2], x[..., 1::2]
        # Apply rotation
        rotated = torch.stack(
            [-x2, x1], dim=-1
        ).flatten(-2)

        return x * cos + rotated * sin


def apply_rope(q, k, rope_module, seq_len):
    """RoPE Application Example"""
    return rope_module(q, k, seq_len)

ALiBi (Attention with Linear Biases)

class ALiBiPositionalBias(nn.Module):
    """ALiBi - Used in MPT, BLOOM etc (⭐⭐⭐⭐)

    Add linear bias to attention scores instead of position embeddings
    No learnable parameters
    Excellent length extrapolation capability
    """
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
        # Different slope per head
        slopes = self._get_slopes(num_heads)
        self.register_buffer('slopes', slopes)

    def _get_slopes(self, n):
        """Compute slopes: 2^(-8/n), 2^(-16/n), ..."""
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(n).is_integer():
            return torch.tensor(get_slopes_power_of_2(n))
        else:
            closest_power = 2 ** math.floor(math.log2(n))
            return torch.tensor(
                get_slopes_power_of_2(closest_power) +
                get_slopes_power_of_2(2 * closest_power)[0::2][:n - closest_power]
            )

    def forward(self, seq_len, device):
        # Relative distance matrix
        positions = torch.arange(seq_len, device=device)
        relative_positions = positions.unsqueeze(0) - positions.unsqueeze(1)
        relative_positions = relative_positions.abs().unsqueeze(0)

        # bias = -slope * distance
        alibi = -self.slopes.unsqueeze(1).unsqueeze(1) * relative_positions

        return alibi  # (num_heads, seq_len, seq_len)


def attention_with_alibi(Q, K, V, alibi_bias):
    """Attention with ALiBi (⭐⭐⭐)"""
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Add ALiBi bias
    scores = scores + alibi_bias

    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V)

6. Attention Visualization

Visualizing Attention Weights

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, tokens=None, layer=0, head=0):
    """Visualize single head attention (⭐⭐)"""
    # attention_weights: (batch, heads, seq, seq)
    weights = attention_weights[0, head].cpu().detach().numpy()

    plt.figure(figsize=(10, 8))
    sns.heatmap(weights, cmap='Blues', square=True)

    if tokens:
        plt.xticks(range(len(tokens)), tokens, rotation=45, ha='right')
        plt.yticks(range(len(tokens)), tokens, rotation=0)

    plt.title(f'Attention Weights (Layer {layer}, Head {head})')
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.tight_layout()
    plt.savefig(f'attention_layer{layer}_head{head}.png')
    plt.close()


def visualize_all_heads(attention_weights, tokens=None, ncols=4):
    """Visualize all heads (⭐⭐)"""
    num_heads = attention_weights.size(1)
    nrows = (num_heads + ncols - 1) // ncols

    fig, axes = plt.subplots(nrows, ncols, figsize=(4*ncols, 4*nrows))
    axes = axes.flatten()

    for head in range(num_heads):
        weights = attention_weights[0, head].cpu().detach().numpy()
        ax = axes[head]
        im = ax.imshow(weights, cmap='Blues')
        ax.set_title(f'Head {head}')
        ax.axis('off')

    for i in range(num_heads, len(axes)):
        axes[i].axis('off')

    plt.tight_layout()
    plt.savefig('attention_all_heads.png')
    plt.close()

Attention Flow (BertViz Style)

def plot_attention_flow(attention_weights, tokens, top_k=3):
    """Visualize top K attention connections (⭐⭐⭐)"""
    weights = attention_weights[0, 0].cpu().detach().numpy()
    seq_len = len(tokens)

    plt.figure(figsize=(12, 8))

    # Token positions
    left_positions = [(0, i) for i in range(seq_len)]
    right_positions = [(1, i) for i in range(seq_len)]

    # Display tokens
    for i, token in enumerate(tokens):
        plt.text(-0.1, i, token, ha='right', va='center', fontsize=10)
        plt.text(1.1, i, token, ha='left', va='center', fontsize=10)

    # Top K connections
    for i in range(seq_len):
        top_indices = weights[i].argsort()[-top_k:]
        for j in top_indices:
            alpha = weights[i, j] / weights[i].max()
            plt.plot([0, 1], [i, j], 'b-', alpha=alpha, linewidth=2)

    plt.xlim(-0.5, 1.5)
    plt.ylim(-1, seq_len)
    plt.axis('off')
    plt.title('Attention Flow')
    plt.savefig('attention_flow.png')
    plt.close()

7. Efficient Implementation

PyTorch scaled_dot_product_attention

def efficient_attention_comparison():
    """Efficient attention implementation comparison (⭐⭐⭐)"""
    import time

    batch, heads, seq, dim = 8, 12, 2048, 64

    Q = torch.randn(batch, heads, seq, dim, device='cuda')
    K = torch.randn(batch, heads, seq, dim, device='cuda')
    V = torch.randn(batch, heads, seq, dim, device='cuda')

    # 1. Naive implementation
    def naive_attention(Q, K, V):
        scale = math.sqrt(Q.size(-1))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, V)

    # 2. PyTorch built-in (may use Flash Attention)
    def pytorch_attention(Q, K, V):
        return F.scaled_dot_product_attention(Q, K, V)

    # Warmup
    _ = naive_attention(Q, K, V)
    _ = pytorch_attention(Q, K, V)
    torch.cuda.synchronize()

    # Benchmark naive
    torch.cuda.reset_peak_memory_stats()
    start = time.time()
    for _ in range(10):
        _ = naive_attention(Q, K, V)
    torch.cuda.synchronize()
    naive_time = (time.time() - start) / 10
    naive_mem = torch.cuda.max_memory_allocated() / 1e9

    # Benchmark pytorch
    torch.cuda.reset_peak_memory_stats()
    start = time.time()
    for _ in range(10):
        _ = pytorch_attention(Q, K, V)
    torch.cuda.synchronize()
    pytorch_time = (time.time() - start) / 10
    pytorch_mem = torch.cuda.max_memory_allocated() / 1e9

    print(f"Naive:   {naive_time*1000:.2f}ms, {naive_mem:.2f}GB")
    print(f"PyTorch: {pytorch_time*1000:.2f}ms, {pytorch_mem:.2f}GB")

    return naive_time, pytorch_time

Memory-Efficient Attention

class MemoryEfficientAttention(nn.Module):
    """Memory Efficient Attention (chunk-based) (⭐⭐⭐⭐)"""
    def __init__(self, d_model, num_heads, chunk_size=256):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.chunk_size = chunk_size
        self.scale = math.sqrt(self.head_dim)

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, N, _ = x.shape

        Q = self.W_q(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        # Process in chunks to save memory
        outputs = []
        for i in range(0, N, self.chunk_size):
            end = min(i + self.chunk_size, N)
            Q_chunk = Q[:, :, i:end]

            # Attention for this chunk against all K, V
            scores = torch.matmul(Q_chunk, K.transpose(-2, -1)) / self.scale
            weights = F.softmax(scores, dim=-1)
            chunk_out = torch.matmul(weights, V)
            outputs.append(chunk_out)

        output = torch.cat(outputs, dim=2)
        output = output.transpose(1, 2).contiguous().view(B, N, self.d_model)

        return self.W_o(output)

8. Attention Pattern Analysis

Common Observed Patterns

def analyze_attention_patterns(attention_weights):
    """Attention Pattern Analysis (⭐⭐⭐)"""
    # attention_weights: (batch, heads, seq, seq)

    patterns = {}

    for head in range(attention_weights.size(1)):
        weights = attention_weights[0, head].cpu().detach()

        # 1. Diagonal (self-attention)
        diagonal = torch.diag(weights).mean()

        # 2. Previous token
        prev_token = torch.diag(weights, -1).mean() if weights.size(0) > 1 else 0

        # 3. First token (CLS etc)
        first_token = weights[:, 0].mean()

        # 4. Uniform (entropy)
        uniform_entropy = -(weights * weights.log()).sum(dim=-1).mean()

        patterns[f'head_{head}'] = {
            'diagonal': diagonal.item(),
            'prev_token': prev_token.item() if isinstance(prev_token, torch.Tensor) else prev_token,
            'first_token': first_token.item(),
            'entropy': uniform_entropy.item()
        }

    return patterns

Summary

Key Concepts

  1. MHA Mathematics: Q, K, V projection → scaling → softmax → weighted sum
  2. Complexity: O(n^2) time/space - bottleneck for long sequences
  3. Flash Attention: Block processing + online softmax → O(n) memory
  4. Sparse Attention: Local, Strided, BigBird etc
  5. Positional Encoding: Sinusoidal, Learned, RoPE, ALiBi

Efficiency Summary

Technique Time Space Features
Standard O(n^2) O(n^2) Baseline
Flash O(n^2) O(n) HW optimized
Local O(nw) O(nw) window w
Sparse O(n sqrt(n)) O(n) Pattern combination

PyTorch Practical Tips

# 1. Use PyTorch 2.0+
output = F.scaled_dot_product_attention(Q, K, V)

# 2. Force Flash Attention activation
with torch.backends.cuda.sdp_kernel(enable_flash=True):
    output = F.scaled_dot_product_attention(Q, K, V)

# 3. Consider chunking for long sequences

# 4. RoPE/ALiBi favorable for length extrapolation

References

  • Flash Attention: https://arxiv.org/abs/2205.14135
  • RoPE: https://arxiv.org/abs/2104.09864
  • ALiBi: https://arxiv.org/abs/2108.12409
  • BigBird: https://arxiv.org/abs/2007.14062
to navigate between lessons