06. Pre-training Infrastructure

Overview

Training a large-scale foundation model runs on thousands of GPUs for weeks to months. This lesson covers distributed training strategies, memory optimization, and training stability techniques.


1. Distributed Training Paradigms

1.1 Parallelization Strategy Overview

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                     Distributed Training Paradigms                β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                  β”‚
β”‚  Data Parallelism (DP)         Tensor Parallelism (TP)          β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”               β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”              β”‚
β”‚  β”‚GPU 0β”‚ β”‚GPU 1β”‚               β”‚   W = [W1 | W2]  β”‚              β”‚
β”‚  β”‚Modelβ”‚ β”‚Modelβ”‚               β”‚GPU0    GPU1      β”‚              β”‚
β”‚  β”‚Data1β”‚ β”‚Data2β”‚               β”‚ W1      W2       β”‚              β”‚
β”‚  β””β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”˜               β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜              β”‚
β”‚  Same model, different data    Split layers across GPUs         β”‚
β”‚                                                                  β”‚
β”‚  Pipeline Parallelism (PP)     Sequence Parallelism (SP)        β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”               β”Œβ”€β”€β”€β”€β”¬β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”             β”‚
β”‚  β”‚GPU 0β”‚ β”‚GPU 1β”‚               β”‚ S1 β”‚ S2 β”‚ S3 β”‚ S4 β”‚             β”‚
β”‚  β”‚L1-L6β”‚β†’β”‚L7-12β”‚               β”‚GPU0β”‚GPU1β”‚GPU2β”‚GPU3β”‚             β”‚
β”‚  β””β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”˜               β””β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”˜             β”‚
β”‚  Sequential layer split        Split sequence across GPUs       β”‚
β”‚                                                                  β”‚
β”‚  3D Parallelism: DP + TP + PP combination                       β”‚
β”‚                                                                  β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

1.2 Memory Analysis

def estimate_training_memory(
    num_params: int,  # Number of parameters
    batch_size: int,
    seq_len: int,
    hidden_dim: int,
    num_layers: int,
    dtype_bytes: int = 2,  # fp16/bf16 = 2, fp32 = 4
    optimizer: str = 'adam'
) -> dict:
    """
    Estimate GPU memory during training

    Memory components:
    1. Model Parameters
    2. Gradients
    3. Optimizer States
    4. Activations (forward pass)
    """

    # 1. Model parameters
    param_memory = num_params * dtype_bytes

    # 2. Gradients (same as parameters)
    grad_memory = num_params * dtype_bytes

    # 3. Optimizer States
    if optimizer == 'adam':
        # Adam: momentum(fp32) + variance(fp32) = 8 bytes per param
        # (mixed-precision training usually also keeps an fp32 master copy
        #  of the parameters, ~4 more bytes/param, omitted in this estimate)
        optimizer_memory = num_params * 4 * 2
    elif optimizer == 'sgd':
        optimizer_memory = num_params * 4  # momentum only
    else:
        optimizer_memory = 0

    # 4. Activations (approximation)
    # Per layer: attention + FFN activations
    bytes_per_token = hidden_dim * dtype_bytes * 10  # approximation
    activation_memory = batch_size * seq_len * bytes_per_token * num_layers

    # Activation checkpointing cuts this to roughly 1/sqrt(L) of the full amount

    total = param_memory + grad_memory + optimizer_memory + activation_memory

    return {
        'parameters_gb': param_memory / 1e9,
        'gradients_gb': grad_memory / 1e9,
        'optimizer_gb': optimizer_memory / 1e9,
        'activations_gb': activation_memory / 1e9,
        'total_gb': total / 1e9
    }


# Example: 7B model
memory = estimate_training_memory(
    num_params=7e9,
    batch_size=4,
    seq_len=2048,
    hidden_dim=4096,
    num_layers=32
)

print("7B model memory estimate:")
for key, value in memory.items():
    print(f"  {key}: {value:.1f} GB")

# Output:
#   parameters_gb: 14.0 GB
#   gradients_gb: 14.0 GB
#   optimizer_gb: 56.0 GB
#   activations_gb: 21.5 GB
#   total_gb: 105.5 GB

2. FSDP (Fully Sharded Data Parallel)

2.1 FSDP Concept

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                      FSDP Operating Principle                β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                             β”‚
β”‚  Traditional DDP:                                           β”‚
β”‚  GPU 0: [Full Model] + [Data 0]                            β”‚
β”‚  GPU 1: [Full Model] + [Data 1]                            β”‚
β”‚  β†’ Full model replicated on each GPU (inefficient)         β”‚
β”‚                                                             β”‚
β”‚  FSDP (Zero Stage 3):                                       β”‚
β”‚  GPU 0: [Shard 0] + [Data 0]                               β”‚
β”‚  GPU 1: [Shard 1] + [Data 1]                               β”‚
β”‚                                                             β”‚
β”‚  Forward: All-Gather to collect full parameters            β”‚
β”‚  Backward: Reduce-Scatter to distribute gradients          β”‚
β”‚                                                             β”‚
β”‚  Memory: (Params + Grads + Optim) / N + Activations        β”‚
β”‚                                                             β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

2.2 PyTorch FSDP Implementation

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy,
)
import functools

def setup_fsdp_training():
    """Setup FSDP training"""

    # Initialize distributed
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Create model
    model = MyTransformerModel(config)

    # Mixed Precision settings
    mixed_precision = MixedPrecision(
        param_dtype=torch.bfloat16,      # Parameters
        reduce_dtype=torch.bfloat16,     # Gradient reduction
        buffer_dtype=torch.bfloat16,     # Buffers
    )

    # Auto Wrap Policy: Shard at Transformer layer level
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock},
    )

    # FSDP wrapping
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # Zero-3
        mixed_precision=mixed_precision,
        auto_wrap_policy=wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        cpu_offload=CPUOffload(offload_params=False),
        device_id=local_rank,
    )

    return model


def train_step_fsdp(model, batch, optimizer):
    """FSDP training step (bf16 autocast needs no GradScaler)"""
    model.train()

    # Forward
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        outputs = model(**batch)
        loss = outputs.loss

    # Backward
    loss.backward()

    # Gradient clipping (requires care with FSDP)
    model.clip_grad_norm_(max_norm=1.0)

    # Optimizer step
    optimizer.step()
    optimizer.zero_grad()

    return loss.item()


# Checkpoint save/load
from torch.distributed.fsdp import (
    FullStateDictConfig,
    StateDictType,
)

def save_fsdp_checkpoint(model, optimizer, path):
    """Save FSDP checkpoint"""

    # Full State Dict config
    full_state_dict_config = FullStateDictConfig(
        offload_to_cpu=True,
        rank0_only=True,
    )

    with FSDP.state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        full_state_dict_config,
    ):
        state_dict = model.state_dict()
        optim_state = FSDP.optim_state_dict(model, optimizer)

        if dist.get_rank() == 0:
            torch.save({
                'model': state_dict,
                'optimizer': optim_state,
            }, path)

    dist.barrier()
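
A matching load path, as a minimal sketch (exact FSDP load APIs vary across PyTorch versions; restoring optimizer state via FSDP.optim_state_dict_to_load is omitted here):

def load_fsdp_checkpoint(model, path):
    """Load a full-state-dict checkpoint back into an FSDP model."""
    # Every rank reads the checkpoint to CPU; load_state_dict inside the
    # FULL_STATE_DICT context distributes values back into local shards
    checkpoint = torch.load(path, map_location="cpu")

    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        model.load_state_dict(checkpoint['model'])

    dist.barrier()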

3. DeepSpeed ZeRO

3.1 ZeRO Stage Comparison

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                     DeepSpeed ZeRO Stages                  β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                            β”‚
β”‚  Stage 1: Optimizer State Partitioning                    β”‚
β”‚  - Only optimizer states (Adam m, v) partitioned          β”‚
β”‚  - Memory savings: ~4x                                     β”‚
β”‚                                                            β”‚
β”‚  Stage 2: + Gradient Partitioning                         β”‚
β”‚  - Gradients also partitioned                             β”‚
β”‚  - Memory savings: ~8x                                     β”‚
β”‚                                                            β”‚
β”‚  Stage 3: + Parameter Partitioning                        β”‚
β”‚  - Parameters also partitioned (similar to FSDP)          β”‚
β”‚  - Memory savings: ~N (proportional to GPU count)         β”‚
β”‚                                                            β”‚
β”‚  ZeRO-Offload: Offload to CPU/NVMe                        β”‚
β”‚  ZeRO-Infinity: Support for infinite model size           β”‚
β”‚                                                            β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

3.2 DeepSpeed Configuration

# ds_config.json (expressed here as a Python dict)
ds_config = {
    "train_batch_size": 256,
    "gradient_accumulation_steps": 8,
    "train_micro_batch_size_per_gpu": 4,

    # FP16 settings
    "fp16": {
        "enabled": True,
        "loss_scale": 0,  # dynamic
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    # BF16 settings (alternative)
    "bf16": {
        "enabled": False
    },

    # ZeRO Stage 3
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",  # or "nvme"
            "pin_memory": True
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": True
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": True,
    },

    # Gradient Checkpointing
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": True,
        "contiguous_memory_optimization": True,
        "number_checkpoints": None,
        "synchronize_checkpoint_boundary": False,
        "profile": False
    },

    # Optimizer
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 3e-4,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 0.01
        }
    },

    # Scheduler
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 3e-4,
            "warmup_num_steps": 1000,
            "total_num_steps": 100000
        }
    }
}

3.3 DeepSpeed Training Code

import deepspeed
import torch

def train_with_deepspeed():
    """DeepSpeed training loop"""

    # Model and data
    model = MyTransformerModel(config)
    train_dataloader = create_dataloader(...)

    # DeepSpeed initialization
    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
    )

    # Training loop
    for epoch in range(num_epochs):
        for step, batch in enumerate(train_dataloader):
            batch = {k: v.to(model_engine.device) for k, v in batch.items()}

            # Forward
            outputs = model_engine(**batch)
            loss = outputs.loss

            # Backward (DeepSpeed handles gradient scaling/accumulation)
            model_engine.backward(loss)

            # Step
            model_engine.step()

            if step % 100 == 0:
                print(f"Step {step}, Loss: {loss.item():.4f}")

    # Save checkpoint
    model_engine.save_checkpoint("checkpoint_dir")


# Execute
# deepspeed --num_gpus=8 train.py --deepspeed_config ds_config.json

4. Activation Checkpointing (Gradient Checkpointing)

4.1 Concept

Normal Forward:
Layer 1 β†’ [Save Act1] β†’ Layer 2 β†’ [Save Act2] β†’ ... β†’ Loss

Use Act1, Act2 during backward to compute gradients
β†’ Memory: O(L) - proportional to layer count

Activation Checkpointing:
Layer 1 β†’ [Checkpoint] β†’ Layer 2 β†’ Layer 3 β†’ [Checkpoint] β†’ ... β†’ Loss

Recompute from checkpoints during backward
β†’ Memory: O(√L) - square root of layer count
β†’ Computation: ~33% increase (recomputation cost)

4.2 Implementation

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

class TransformerBlockWithCheckpoint(nn.Module):
    """Transformer block with checkpointing"""

    def __init__(self, config, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint

        self.attention = MultiHeadAttention(config)
        self.ffn = FeedForward(config)
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)

    def forward(self, x, attention_mask=None):
        if self.use_checkpoint and self.training:
            # Use checkpointing
            return checkpoint(
                self._forward_impl,
                x, attention_mask,
                use_reentrant=False,  # PyTorch 2.0+ recommended
            )
        else:
            return self._forward_impl(x, attention_mask)

    def _forward_impl(self, x, attention_mask):
        # Attention
        residual = x
        x = self.norm1(x)
        x = self.attention(x, attention_mask)
        x = residual + x

        # FFN
        residual = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = residual + x

        return x


class TransformerWithSelectiveCheckpoint(nn.Module):
    """Selective Checkpointing"""

    def __init__(self, config, checkpoint_ratio=0.5):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlockWithCheckpoint(
                config,
                # Only checkpoint some layers
                use_checkpoint=(i % int(1/checkpoint_ratio) == 0)
            )
            for i in range(config.num_layers)
        ])

    def forward(self, x, attention_mask=None):
        for layer in self.layers:
            x = layer(x, attention_mask)
        return x
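
For a purely sequential stack of layers, the checkpoint_sequential helper imported in 4.2 gives the same effect without a custom block. A minimal self-contained sketch with toy sizes (~sqrt(L) segments is the usual choice):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# 16 layers split into 4 segments: activations are stored only at segment
# boundaries and recomputed inside each segment during backward
layers = nn.Sequential(*[nn.Linear(256, 256) for _ in range(16)])
x = torch.randn(8, 256, requires_grad=True)

out = checkpoint_sequential(layers, segments=4, input=x, use_reentrant=False)
out.sum().backward()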

5. Training Stability

5.1 Loss Spike Response

class TrainingStabilizer:
    """Training stability management"""

    def __init__(
        self,
        loss_spike_threshold: float = 5.0,  # 5x the recent average loss
        grad_norm_threshold: float = 10.0,
        window_size: int = 100
    ):
        self.loss_spike_threshold = loss_spike_threshold
        self.grad_norm_threshold = grad_norm_threshold
        self.window_size = window_size

        self.loss_history = []
        self.grad_norm_history = []
        self.skipped_steps = 0

    def check_loss_spike(self, loss: float) -> bool:
        """Detect loss spike"""
        if len(self.loss_history) < self.window_size:
            self.loss_history.append(loss)
            return False

        avg_loss = sum(self.loss_history[-self.window_size:]) / self.window_size

        if loss > avg_loss * self.loss_spike_threshold:
            print(f"⚠️ Loss spike detected: {loss:.4f} (avg: {avg_loss:.4f})")
            return True

        self.loss_history.append(loss)
        return False

    def check_grad_norm(self, model: nn.Module) -> tuple[float, bool]:
        """Check gradient norm"""
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5

        is_spike = total_norm > self.grad_norm_threshold

        if is_spike:
            print(f"⚠️ Gradient spike: {total_norm:.4f}")

        self.grad_norm_history.append(total_norm)
        return total_norm, is_spike

    def should_skip_step(self, loss: float, model: nn.Module) -> bool:
        """Decide whether to skip this step"""
        loss_spike = self.check_loss_spike(loss)
        _, grad_spike = self.check_grad_norm(model)

        if loss_spike or grad_spike:
            self.skipped_steps += 1
            return True

        return False


def stable_training_step(
    model, batch, optimizer, stabilizer, scaler=None
):
    """Stable training step"""

    # Forward
    with torch.cuda.amp.autocast():
        outputs = model(**batch)
        loss = outputs.loss

    # Backward
    if scaler:
        scaler.scale(loss).backward()
        # Unscale before inspecting or clipping gradients
        scaler.unscale_(optimizer)
    else:
        loss.backward()

    # Spike checks must run after backward, so the gradient norm
    # reflects the current batch
    if stabilizer.should_skip_step(loss.item(), model):
        optimizer.zero_grad()
        print(f"Skipping step (total skipped: {stabilizer.skipped_steps})")
        return None

    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Optimizer step
    if scaler:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()

    optimizer.zero_grad()

    return loss.item()

5.2 Checkpoint Strategy

import os
import torch
from datetime import datetime

class CheckpointManager:
    """Checkpoint management"""

    def __init__(
        self,
        save_dir: str,
        max_checkpoints: int = 5,
        save_interval_steps: int = 1000,
        save_interval_hours: float = 1.0
    ):
        self.save_dir = save_dir
        self.max_checkpoints = max_checkpoints
        self.save_interval_steps = save_interval_steps
        self.save_interval_hours = save_interval_hours

        self.last_save_time = datetime.now()
        self.checkpoints = []

        os.makedirs(save_dir, exist_ok=True)

    def should_save(self, step: int) -> bool:
        """Decide whether to save checkpoint"""
        # Step-based (skip step 0)
        if step > 0 and step % self.save_interval_steps == 0:
            return True

        # Time-based
        elapsed = (datetime.now() - self.last_save_time).total_seconds() / 3600
        if elapsed >= self.save_interval_hours:
            return True

        return False

    def save(
        self,
        model,
        optimizer,
        scheduler,
        step: int,
        loss: float,
        **extra
    ):
        """Save checkpoint"""
        checkpoint_name = f"checkpoint-{step}"
        checkpoint_path = os.path.join(self.save_dir, checkpoint_name)

        # Save
        state = {
            'step': step,
            'loss': loss,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            **extra
        }

        torch.save(state, checkpoint_path + ".pt")

        # Metadata
        self.checkpoints.append({
            'path': checkpoint_path,
            'step': step,
            'loss': loss,
            'time': datetime.now().isoformat()
        })

        self.last_save_time = datetime.now()

        # Remove old checkpoints
        self._cleanup()

        print(f"πŸ’Ύ Saved checkpoint: {checkpoint_name}")

    def _cleanup(self):
        """Clean up old checkpoints"""
        while len(self.checkpoints) > self.max_checkpoints:
            oldest = self.checkpoints.pop(0)
            if os.path.exists(oldest['path'] + ".pt"):
                os.remove(oldest['path'] + ".pt")
                print(f"πŸ—‘οΈ Removed old checkpoint: {oldest['path']}")

    def load_latest(self) -> dict:
        """Load latest checkpoint"""
        if not self.checkpoints:
            # Find in directory; sort numerically by step (a lexicographic
            # sort would put checkpoint-9000 after checkpoint-10000)
            files = [
                f for f in os.listdir(self.save_dir)
                if f.startswith("checkpoint-") and f.endswith(".pt")
            ]

            if not files:
                return None

            latest = max(files, key=lambda f: int(f[len("checkpoint-"):-len(".pt")]))
            return torch.load(os.path.join(self.save_dir, latest))

        return torch.load(self.checkpoints[-1]['path'] + ".pt")

6. Learning Rate Scheduling

6.1 Warmup + Cosine Decay

import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.1,
    num_cycles: float = 0.5
):
    """
    Warmup + Cosine Decay scheduler

    Early training: Linear warmup (0 β†’ max_lr)
    After: Cosine decay (max_lr β†’ min_lr)
    """

    def lr_lambda(current_step):
        # Warmup
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        # Cosine decay
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )

        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))

        # Decay only to min_lr
        decayed = (1 - min_lr_ratio) * cosine_decay + min_lr_ratio

        return decayed

    return LambdaLR(optimizer, lr_lambda)


# WSD (Warmup-Stable-Decay) scheduler (used in e.g. MiniCPM)
def get_wsd_schedule(
    optimizer,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_decay_steps: int,
    min_lr_ratio: float = 0.1
):
    """
    Warmup-Stable-Decay scheduler

    1. Warmup: 0 β†’ max_lr
    2. Stable: maintain max_lr
    3. Decay: max_lr β†’ min_lr (cosine)
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # Warmup phase
            return float(current_step) / float(max(1, num_warmup_steps))

        elif current_step < num_warmup_steps + num_stable_steps:
            # Stable phase
            return 1.0

        else:
            # Decay phase
            decay_step = current_step - num_warmup_steps - num_stable_steps
            progress = float(decay_step) / float(max(1, num_decay_steps))
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            return (1 - min_lr_ratio) * cosine_decay + min_lr_ratio

    return LambdaLR(optimizer, lr_lambda)
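
A quick sanity check of the WSD schedule (a throwaway sketch with a single dummy parameter), printing the learning-rate trajectory through all three phases:

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=3e-4)
sched = get_wsd_schedule(opt, num_warmup_steps=100,
                         num_stable_steps=800, num_decay_steps=100)

for step in range(1000):
    opt.step()       # optimizer step before scheduler step
    sched.step()
    if step in (49, 99, 499, 949, 999):
        print(f"step {step + 1}: lr = {sched.get_last_lr()[0]:.2e}")

# Expected: ~1.5e-04 mid-warmup, 3.0e-04 through the stable phase,
# then cosine decay down to min_lr_ratio * max_lr = 3.0e-05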

7. Practice: Complete Training Script

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import wandb

def main():
    """Complete distributed training script"""

    # 1. Initialize distributed
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = dist.get_world_size()
    torch.cuda.set_device(local_rank)

    # Only global rank 0 logs (local_rank is per-node, so use the global rank)
    is_main = dist.get_rank() == 0

    if is_main:
        wandb.init(project="foundation-model-training")

    # 2. Configuration
    config = {
        'hidden_size': 4096,
        'num_layers': 32,
        'num_heads': 32,
        'vocab_size': 50257,
        'max_seq_len': 2048,
        'batch_size': 4,  # per GPU
        'gradient_accumulation': 8,
        'learning_rate': 3e-4,
        'warmup_steps': 2000,
        'total_steps': 100000,
        'weight_decay': 0.1,
        'max_grad_norm': 1.0,
    }

    effective_batch = config['batch_size'] * config['gradient_accumulation'] * world_size
    print(f"Effective batch size: {effective_batch}")

    # 3. Model
    model = TransformerModel(config).cuda()

    # Activation checkpointing
    model.gradient_checkpointing_enable()

    # DDP or FSDP
    model = DDP(model, device_ids=[local_rank])

    # 4. Data
    dataset = PretrainingDataset(config)
    sampler = DistributedSampler(dataset, shuffle=True)
    dataloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
    )

    # 5. Optimizer & Scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.95),
    )

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config['warmup_steps'],
        num_training_steps=config['total_steps'],
    )

    # 6. Utilities
    # Note: GradScaler is only required for fp16; with bf16 autocast,
    # loss scaling is unnecessary (bf16 has fp32-like dynamic range)
    scaler = torch.cuda.amp.GradScaler()
    stabilizer = TrainingStabilizer()
    checkpoint_mgr = CheckpointManager("checkpoints")

    # Resume from checkpoint
    checkpoint = checkpoint_mgr.load_latest()
    start_step = 0
    if checkpoint:
        model.module.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if checkpoint['scheduler_state_dict']:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_step = checkpoint['step']
        if is_main:
            print(f"Resumed from step {start_step}")

    # 7. Training loop
    model.train()
    global_step = start_step
    accumulated_loss = 0.0

    for epoch in range(100):  # Sufficiently large number
        sampler.set_epoch(epoch)

        for batch_idx, batch in enumerate(dataloader):
            batch = {k: v.cuda() for k, v in batch.items()}

            # Forward (Mixed Precision)
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                outputs = model(**batch)
                loss = outputs.loss / config['gradient_accumulation']

            # Backward
            scaler.scale(loss).backward()
            accumulated_loss += loss.item()

            # Gradient Accumulation
            if (batch_idx + 1) % config['gradient_accumulation'] == 0:
                # Gradient clipping
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    config['max_grad_norm']
                )

                # Step
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()

                global_step += 1

                # Logging
                if is_main and global_step % 10 == 0:
                    lr = scheduler.get_last_lr()[0]
                    wandb.log({
                        'loss': accumulated_loss,
                        'learning_rate': lr,
                        'grad_norm': grad_norm.item(),
                        'step': global_step,
                    })
                    print(f"Step {global_step}: loss={accumulated_loss:.4f}, lr={lr:.2e}")

                # Checkpoint (save before resetting the accumulated loss)
                if checkpoint_mgr.should_save(global_step):
                    if is_main:
                        checkpoint_mgr.save(
                            model.module, optimizer, scheduler,
                            global_step, accumulated_loss
                        )

                accumulated_loss = 0.0

                # Termination condition
                if global_step >= config['total_steps']:
                    break

        if global_step >= config['total_steps']:
            break

    # Cleanup
    dist.destroy_process_group()
    if is_main:
        wandb.finish()


if __name__ == "__main__":
    main()

# Execute:
# torchrun --nproc_per_node=8 --nnodes=4 --node_rank=0 \
#          --master_addr="master" --master_port=29500 train.py

References

Papers

  • Rajbhandari et al. (2020). "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models"
  • Narayanan et al. (2021). "Efficient Large-Scale Language Model Training on GPU Clusters"