06. Pre-training Infrastructure¶
Overview¶
Training a large-scale foundation model occupies thousands of GPUs for weeks to months. This lesson covers distributed training strategies, memory optimization, and training stability techniques.
1. Distributed Training Paradigms¶
1.1 Parallelization Strategy Overview¶
┌────────────────────────────────────────────────────────────────────┐
│                   Distributed Training Paradigms                   │
├────────────────────────────────────────────────────────────────────┤
│                                                                    │
│  Data Parallelism (DP)          Tensor Parallelism (TP)            │
│  ┌─────┐  ┌─────┐               ┌──────────────────┐               │
│  │GPU 0│  │GPU 1│               │   W = [W1 | W2]  │               │
│  │Model│  │Model│               │  GPU0      GPU1  │               │
│  │Data1│  │Data2│               │   W1        W2   │               │
│  └─────┘  └─────┘               └──────────────────┘               │
│  Same model, different data     Split each layer's weights         │
│                                                                    │
│  Pipeline Parallelism (PP)      Sequence Parallelism (SP)          │
│  ┌─────┐  ┌─────┐               ┌────┬────┬────┬────┐              │
│  │GPU 0│  │GPU 1│               │ S1 │ S2 │ S3 │ S4 │              │
│  │L1-L6│─→│L7-12│               │GPU0│GPU1│GPU2│GPU3│              │
│  └─────┘  └─────┘               └────┴────┴────┴────┘              │
│  Sequential layer split         Split sequence across GPUs         │
│                                                                    │
│  3D Parallelism: DP + TP + PP combination                          │
│                                                                    │
└────────────────────────────────────────────────────────────────────┘
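To see how the three axes compose, the sketch below factors a GPU's global rank into one coordinate per parallelism axis. The ordering (TP fastest-varying, DP slowest) follows the common Megatron-LM convention but is an assumption, not a fixed standard:

# Sketch: map a global rank onto (DP, PP, TP) coordinates for a 3D-parallel
# job. The ordering convention (TP innermost) is an assumption.
def rank_to_3d(rank: int, tp_size: int, pp_size: int, dp_size: int):
    assert rank < tp_size * pp_size * dp_size
    tp_rank = rank % tp_size                # tensor-parallel group member
    pp_rank = (rank // tp_size) % pp_size   # pipeline stage
    dp_rank = rank // (tp_size * pp_size)   # data-parallel replica
    return dp_rank, pp_rank, tp_rank

# 16 GPUs arranged as 2-way DP x 2-way PP x 4-way TP
for r in range(16):
    print(f"rank {r:2d} -> dp/pp/tp = {rank_to_3d(r, tp_size=4, pp_size=2, dp_size=2)}")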
1.2 Memory Analysis¶
def estimate_training_memory(
num_params: int, # Number of parameters
batch_size: int,
seq_len: int,
hidden_dim: int,
num_layers: int,
dtype_bytes: int = 2, # fp16/bf16 = 2, fp32 = 4
optimizer: str = 'adam'
) -> dict:
"""
Estimate GPU memory during training
Memory components:
1. Model Parameters
2. Gradients
3. Optimizer States
4. Activations (forward pass)
"""
# 1. Model parameters
param_memory = num_params * dtype_bytes
# 2. Gradients (same as parameters)
grad_memory = num_params * dtype_bytes
# 3. Optimizer States
if optimizer == 'adam':
# Adam: momentum(fp32) + variance(fp32)
optimizer_memory = num_params * 4 * 2 # 8 bytes per param
elif optimizer == 'sgd':
optimizer_memory = num_params * 4 # momentum only
else:
optimizer_memory = 0
# 4. Activations (approximation)
# Per layer: attention + FFN activations
bytes_per_token = hidden_dim * dtype_bytes * 10 # approximation
activation_memory = batch_size * seq_len * bytes_per_token * num_layers
    # Activation checkpointing reduces this term from O(L) toward O(sqrt(L))
total = param_memory + grad_memory + optimizer_memory + activation_memory
return {
'parameters_gb': param_memory / 1e9,
'gradients_gb': grad_memory / 1e9,
'optimizer_gb': optimizer_memory / 1e9,
'activations_gb': activation_memory / 1e9,
'total_gb': total / 1e9
}
# Example: 7B model
memory = estimate_training_memory(
num_params=7e9,
batch_size=4,
seq_len=2048,
hidden_dim=4096,
num_layers=32
)
print("7B model memory estimate:")
for key, value in memory.items():
print(f" {key}: {value:.1f} GB")
# Output:
# parameters_gb: 14.0 GB
# gradients_gb: 14.0 GB
# optimizer_gb: 56.0 GB
# activations_gb: ~21.5 GB (batch_size=4)
# total_gb: ~105.5 GB
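Note that ~105 GB exceeds the 80 GB of a single A100/H100 even at this small per-GPU batch, and ~84 GB of it is parameter, gradient, and optimizer state. That state is exactly what the sharding strategies below split across devices.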
2. FSDP (Fully Sharded Data Parallel)¶
2.1 FSDP Concept¶
┌──────────────────────────────────────────────────────────────┐
│                   FSDP Operating Principle                   │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  Traditional DDP:                                            │
│    GPU 0: [Full Model] + [Data 0]                            │
│    GPU 1: [Full Model] + [Data 1]                            │
│    → Full model replicated on each GPU (inefficient)         │
│                                                              │
│  FSDP (ZeRO Stage 3):                                        │
│    GPU 0: [Shard 0] + [Data 0]                               │
│    GPU 1: [Shard 1] + [Data 1]                               │
│                                                              │
│    Forward:  All-Gather to collect full parameters           │
│    Backward: Reduce-Scatter to distribute gradients          │
│                                                              │
│    Memory: (Params + Grads + Optim) / N + Activations        │
│                                                              │
└──────────────────────────────────────────────────────────────┘
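Plugging in the 7B estimate from section 1.2: with FULL_SHARD across N = 8 GPUs, the 14 + 14 + 56 = 84 GB of parameter, gradient, and optimizer state shrinks to 84 / 8 = 10.5 GB per GPU, plus each GPU's own activation memory.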
2.2 PyTorch FSDP Implementation¶
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
CPUOffload,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
size_based_auto_wrap_policy,
)
import functools
def setup_fsdp_training():
"""Setup FSDP training"""
# Initialize distributed
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
# Create model
model = MyTransformerModel(config)
# Mixed Precision settings
mixed_precision = MixedPrecision(
param_dtype=torch.bfloat16, # Parameters
reduce_dtype=torch.bfloat16, # Gradient reduction
buffer_dtype=torch.bfloat16, # Buffers
)
# Auto Wrap Policy: Shard at Transformer layer level
wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerBlock},
)
# FSDP wrapping
model = FSDP(
model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 equivalent
mixed_precision=mixed_precision,
auto_wrap_policy=wrap_policy,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
cpu_offload=CPUOffload(offload_params=False),
device_id=local_rank,
)
return model
def train_step_fsdp(model, batch, optimizer):
    """FSDP training step (bf16 autocast needs no GradScaler)"""
    model.train()
    # Forward
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        outputs = model(**batch)
        loss = outputs.loss
    # Backward
    loss.backward()
    # Gradient clipping: use FSDP's own method, which handles sharded grads
    model.clip_grad_norm_(max_norm=1.0)
    # Optimizer step
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
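# FSDP scripts launch the same way as DDP ones, e.g. (script name illustrative):
#   torchrun --nproc_per_node=8 train_fsdp.py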
# Checkpoint save/load
from torch.distributed.fsdp import (
FullStateDictConfig,
StateDictType,
)
def save_fsdp_checkpoint(model, optimizer, path):
"""Save FSDP checkpoint"""
# Full State Dict config
full_state_dict_config = FullStateDictConfig(
offload_to_cpu=True,
rank0_only=True,
)
with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
full_state_dict_config,
):
state_dict = model.state_dict()
optim_state = FSDP.optim_state_dict(model, optimizer)
if dist.get_rank() == 0:
torch.save({
'model': state_dict,
'optimizer': optim_state,
}, path)
dist.barrier()
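The loading path mirrors the save: materialize the full state dict on CPU, then let FSDP re-shard it. A minimal sketch, assuming the PyTorch 2.1+ argument order for FSDP.optim_state_dict_to_load (older releases order the arguments differently):

def load_fsdp_checkpoint(model, optimizer, path):
    """Load a full-state-dict checkpoint into an FSDP model (sketch)"""
    # Every rank loads the checkpoint to CPU; FSDP re-shards on load
    checkpoint = torch.load(path, map_location="cpu")
    with FSDP.state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    ):
        model.load_state_dict(checkpoint['model'])
    # Convert the consolidated optimizer state back to the sharded layout
    optim_state = FSDP.optim_state_dict_to_load(
        model, optimizer, checkpoint['optimizer']
    )
    optimizer.load_state_dict(optim_state)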
3. DeepSpeed ZeRO¶
3.1 ZeRO Stage Comparison¶
┌────────────────────────────────────────────────────────────┐
│                   DeepSpeed ZeRO Stages                    │
├────────────────────────────────────────────────────────────┤
│                                                            │
│  Stage 1: Optimizer State Partitioning                     │
│   - Only optimizer states (Adam m, v) partitioned          │
│   - Memory savings: ~4x                                    │
│                                                            │
│  Stage 2: + Gradient Partitioning                          │
│   - Gradients also partitioned                             │
│   - Memory savings: ~8x                                    │
│                                                            │
│  Stage 3: + Parameter Partitioning                         │
│   - Parameters also partitioned (similar to FSDP)          │
│   - Memory savings: ~N (proportional to GPU count)         │
│                                                            │
│  ZeRO-Offload: Offload to CPU/NVMe                         │
│  ZeRO-Infinity: NVMe offload for models beyond GPU memory  │
│                                                            │
└────────────────────────────────────────────────────────────┘
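To make these savings concrete, here is a rough per-GPU calculator for model/optimizer state alone (activations excluded), following the ZeRO paper's mixed-precision accounting: 2-byte fp16 parameters and gradients plus ~12 bytes/param of fp32 optimizer state (master weights + Adam m, v). The constants are approximations.

def zero_state_per_gpu_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Approximate per-GPU param/grad/optimizer memory under ZeRO."""
    params = 2 * num_params   # fp16 parameters
    grads = 2 * num_params    # fp16 gradients
    optim = 12 * num_params   # fp32 master weights + Adam m, v
    if stage >= 1:
        optim /= num_gpus     # Stage 1: shard optimizer states
    if stage >= 2:
        grads /= num_gpus     # Stage 2: also shard gradients
    if stage >= 3:
        params /= num_gpus    # Stage 3: also shard parameters
    return (params + grads + optim) / 1e9

for stage in range(4):
    print(f"Stage {stage}: {zero_state_per_gpu_gb(7e9, 64, stage):.1f} GB/GPU")
# 7B params, 64 GPUs: Stage 0: 112.0, Stage 1: 29.3, Stage 2: 15.5, Stage 3: 1.8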
3.2 DeepSpeed Configuration¶
# DeepSpeed configuration (Python-dict form of ds_config.json)
ds_config = {
"train_batch_size": 256,
"gradient_accumulation_steps": 8,
"train_micro_batch_size_per_gpu": 4,
# FP16 settings
"fp16": {
"enabled": True,
"loss_scale": 0, # dynamic
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
# BF16 settings (alternative)
"bf16": {
"enabled": False
},
# ZeRO Stage 3
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu", # or "nvme"
"pin_memory": True
},
"offload_param": {
"device": "cpu",
"pin_memory": True
},
"overlap_comm": True,
"contiguous_gradients": True,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": True,
},
# Gradient Checkpointing
"activation_checkpointing": {
"partition_activations": True,
"cpu_checkpointing": True,
"contiguous_memory_optimization": True,
"number_checkpoints": None,
"synchronize_checkpoint_boundary": False,
"profile": False
},
# Optimizer
"optimizer": {
"type": "AdamW",
"params": {
"lr": 3e-4,
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 0.01
}
},
# Scheduler
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 3e-4,
"warmup_num_steps": 1000,
"total_num_steps": 100000
}
}
}
3.3 DeepSpeed Training Code¶
import deepspeed
import torch
def train_with_deepspeed():
"""DeepSpeed training loop"""
# Model and data
model = MyTransformerModel(config)
train_dataloader = create_dataloader(...)
# DeepSpeed initialization
model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
config=ds_config,
)
# Training loop
for epoch in range(num_epochs):
for step, batch in enumerate(train_dataloader):
batch = {k: v.to(model_engine.device) for k, v in batch.items()}
# Forward
outputs = model_engine(**batch)
loss = outputs.loss
# Backward (DeepSpeed handles gradient scaling/accumulation)
model_engine.backward(loss)
# Step
model_engine.step()
if step % 100 == 0:
print(f"Step {step}, Loss: {loss.item():.4f}")
# Save checkpoint
model_engine.save_checkpoint("checkpoint_dir")
# Execute
# deepspeed --num_gpus=8 train.py --deepspeed_config ds_config.json
4. Activation Checkpointing (Gradient Checkpointing)¶
4.1 Concept¶
Normal Forward:
  Layer 1 → [Save Act1] → Layer 2 → [Save Act2] → ... → Loss
  Use Act1, Act2 during backward to compute gradients
  → Memory: O(L) - proportional to layer count

Activation Checkpointing:
  Layer 1 → [Checkpoint] → Layer 2 → Layer 3 → [Checkpoint] → ... → Loss
  Recompute from checkpoints during backward
  → Memory: O(√L) - square root of layer count
  → Computation: ~33% increase (recomputation cost)
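For the 7B example from section 1.2, checkpointing every block means storing only each block's input (4 × 2048 × 4096 × 2 bytes × 32 layers ≈ 2.1 GB) instead of all ~21.5 GB of intermediate activations, in exchange for roughly one extra forward pass during backward.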
4.2 Implementation¶
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class TransformerBlockWithCheckpoint(nn.Module):
"""Transformer block with checkpointing"""
def __init__(self, config, use_checkpoint=True):
super().__init__()
self.use_checkpoint = use_checkpoint
self.attention = MultiHeadAttention(config)
self.ffn = FeedForward(config)
self.norm1 = nn.LayerNorm(config.hidden_size)
self.norm2 = nn.LayerNorm(config.hidden_size)
def forward(self, x, attention_mask=None):
if self.use_checkpoint and self.training:
# Use checkpointing
return checkpoint(
self._forward_impl,
x, attention_mask,
use_reentrant=False, # PyTorch 2.0+ recommended
)
else:
return self._forward_impl(x, attention_mask)
def _forward_impl(self, x, attention_mask):
# Attention
residual = x
x = self.norm1(x)
x = self.attention(x, attention_mask)
x = residual + x
# FFN
residual = x
x = self.norm2(x)
x = self.ffn(x)
x = residual + x
return x
class TransformerWithSelectiveCheckpoint(nn.Module):
"""Selective Checkpointing"""
def __init__(self, config, checkpoint_ratio=0.5):
super().__init__()
self.layers = nn.ModuleList([
TransformerBlockWithCheckpoint(
config,
# Only checkpoint some layers
use_checkpoint=(i % int(1/checkpoint_ratio) == 0)
)
for i in range(config.num_layers)
])
def forward(self, x, attention_mask=None):
for layer in self.layers:
x = layer(x, attention_mask)
return x
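With checkpoint_ratio=0.5, every other block is checkpointed, trading away roughly half the activation savings of full checkpointing for roughly half its recomputation overhead; in practice the ratio is tuned down just far enough that the model fits in memory.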
5. Training Stability¶
5.1 Loss Spike Response¶
import torch
import torch.nn as nn

class TrainingStabilizer:
    """Training stability management"""
    def __init__(
        self,
        loss_spike_threshold: float = 5.0,  # spike if > 5x the recent average loss
        grad_norm_threshold: float = 10.0,
        window_size: int = 100
    ):
self.loss_spike_threshold = loss_spike_threshold
self.grad_norm_threshold = grad_norm_threshold
self.window_size = window_size
self.loss_history = []
self.grad_norm_history = []
self.skipped_steps = 0
def check_loss_spike(self, loss: float) -> bool:
"""Detect loss spike"""
if len(self.loss_history) < self.window_size:
self.loss_history.append(loss)
return False
avg_loss = sum(self.loss_history[-self.window_size:]) / self.window_size
if loss > avg_loss * self.loss_spike_threshold:
print(f"β οΈ Loss spike detected: {loss:.4f} (avg: {avg_loss:.4f})")
return True
self.loss_history.append(loss)
return False
def check_grad_norm(self, model: nn.Module) -> tuple[float, bool]:
"""Check gradient norm"""
total_norm = 0.0
for p in model.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
is_spike = total_norm > self.grad_norm_threshold
if is_spike:
print(f"β οΈ Gradient spike: {total_norm:.4f}")
self.grad_norm_history.append(total_norm)
return total_norm, is_spike
def should_skip_step(self, loss: float, model: nn.Module) -> bool:
"""Decide whether to skip this step"""
loss_spike = self.check_loss_spike(loss)
_, grad_spike = self.check_grad_norm(model)
if loss_spike or grad_spike:
self.skipped_steps += 1
return True
return False
def stable_training_step(
    model, batch, optimizer, stabilizer, scaler=None
):
    """Stable training step"""
    # Forward
    with torch.cuda.amp.autocast():
        outputs = model(**batch)
        loss = outputs.loss
    # Backward first: gradients must exist before the grad-norm check
    if scaler:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale before inspecting/clipping grads
    else:
        loss.backward()
    # Skip the update on a loss or gradient spike
    if stabilizer.should_skip_step(loss.item(), model):
        optimizer.zero_grad()
        if scaler:
            scaler.update()  # keep the scaler's per-step state consistent
        print(f"Skipping step (total skipped: {stabilizer.skipped_steps})")
        return None
    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # Optimizer step
    if scaler:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    optimizer.zero_grad()
    return loss.item()
5.2 Checkpoint Strategy¶
import os
from datetime import datetime

import torch
class CheckpointManager:
"""Checkpoint management"""
def __init__(
self,
save_dir: str,
max_checkpoints: int = 5,
save_interval_steps: int = 1000,
save_interval_hours: float = 1.0
):
self.save_dir = save_dir
self.max_checkpoints = max_checkpoints
self.save_interval_steps = save_interval_steps
self.save_interval_hours = save_interval_hours
self.last_save_time = datetime.now()
self.checkpoints = []
os.makedirs(save_dir, exist_ok=True)
def should_save(self, step: int) -> bool:
"""Decide whether to save checkpoint"""
        # Step-based (skip step 0)
        if step > 0 and step % self.save_interval_steps == 0:
            return True
# Time-based
elapsed = (datetime.now() - self.last_save_time).total_seconds() / 3600
if elapsed >= self.save_interval_hours:
return True
return False
def save(
self,
model,
optimizer,
scheduler,
step: int,
loss: float,
**extra
):
"""Save checkpoint"""
checkpoint_name = f"checkpoint-{step}"
checkpoint_path = os.path.join(self.save_dir, checkpoint_name)
# Save
state = {
'step': step,
'loss': loss,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
**extra
}
torch.save(state, checkpoint_path + ".pt")
# Metadata
self.checkpoints.append({
'path': checkpoint_path,
'step': step,
'loss': loss,
'time': datetime.now().isoformat()
})
self.last_save_time = datetime.now()
# Remove old checkpoints
self._cleanup()
print(f"πΎ Saved checkpoint: {checkpoint_name}")
def _cleanup(self):
"""Clean up old checkpoints"""
while len(self.checkpoints) > self.max_checkpoints:
oldest = self.checkpoints.pop(0)
if os.path.exists(oldest['path'] + ".pt"):
os.remove(oldest['path'] + ".pt")
print(f"ποΈ Removed old checkpoint: {oldest['path']}")
    def load_latest(self) -> dict:
        """Load latest checkpoint"""
        if not self.checkpoints:
            # Scan the directory; sort numerically by step, not lexically
            files = [
                f for f in os.listdir(self.save_dir)
                if f.startswith("checkpoint-") and f.endswith(".pt")
            ]
            if not files:
                return None
            latest = max(files, key=lambda f: int(f[len("checkpoint-"):-len(".pt")]))
            return torch.load(os.path.join(self.save_dir, latest))
        return torch.load(self.checkpoints[-1]['path'] + ".pt")
6. Learning Rate Scheduling¶
6.1 Warmup + Cosine Decay¶
import math
from torch.optim.lr_scheduler import LambdaLR
def get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps: int,
num_training_steps: int,
min_lr_ratio: float = 0.1,
num_cycles: float = 0.5
):
"""
Warmup + Cosine Decay scheduler
Early training: Linear warmup (0 β max_lr)
After: Cosine decay (max_lr β min_lr)
"""
def lr_lambda(current_step):
# Warmup
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
# Cosine decay
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
# Decay only to min_lr
decayed = (1 - min_lr_ratio) * cosine_decay + min_lr_ratio
return decayed
return LambdaLR(optimizer, lr_lambda)
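# Quick sanity check of the schedule shape, using a throwaway one-parameter
# optimizer (illustrative only):
import torch
probe_opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=3e-4)
probe_sched = get_cosine_schedule_with_warmup(
    probe_opt, num_warmup_steps=100, num_training_steps=1000
)
for step in range(1000):
    probe_opt.step()
    probe_sched.step()
    if step in (49, 99, 499, 999):
        print(f"step {step}: lr = {probe_sched.get_last_lr()[0]:.2e}")
# lr climbs linearly to 3e-4 by step 100, then decays toward 3e-5 (min_lr_ratio=0.1)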
# WSD (Warmup-Stable-Decay) scheduler (used by MiniCPM, among others)
def get_wsd_schedule(
optimizer,
num_warmup_steps: int,
num_stable_steps: int,
num_decay_steps: int,
min_lr_ratio: float = 0.1
):
"""
Warmup-Stable-Decay scheduler
1. Warmup: 0 β max_lr
2. Stable: maintain max_lr
3. Decay: max_lr β min_lr (cosine)
"""
total_steps = num_warmup_steps + num_stable_steps + num_decay_steps
def lr_lambda(current_step):
if current_step < num_warmup_steps:
# Warmup phase
return float(current_step) / float(max(1, num_warmup_steps))
elif current_step < num_warmup_steps + num_stable_steps:
# Stable phase
return 1.0
else:
# Decay phase
decay_step = current_step - num_warmup_steps - num_stable_steps
progress = float(decay_step) / float(max(1, num_decay_steps))
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
return (1 - min_lr_ratio) * cosine_decay + min_lr_ratio
return LambdaLR(optimizer, lr_lambda)
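A practical advantage of WSD over a single cosine schedule: because the stable phase holds max_lr constant, any stable-phase checkpoint can later be branched into its own short decay phase, so the total training length need not be fixed in advance.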
7. Practice: Complete Training Script¶
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import wandb
def main():
"""Complete distributed training script"""
# 1. Initialize distributed
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
world_size = dist.get_world_size()
torch.cuda.set_device(local_rank)
    # Only global rank 0 logs (local_rank 0 exists on every node)
    is_main = dist.get_rank() == 0
    if is_main:
        wandb.init(project="foundation-model-training")
# 2. Configuration
config = {
'hidden_size': 4096,
'num_layers': 32,
'num_heads': 32,
'vocab_size': 50257,
'max_seq_len': 2048,
'batch_size': 4, # per GPU
'gradient_accumulation': 8,
'learning_rate': 3e-4,
'warmup_steps': 2000,
'total_steps': 100000,
'weight_decay': 0.1,
'max_grad_norm': 1.0,
}
    effective_batch = config['batch_size'] * config['gradient_accumulation'] * world_size
    if is_main:
        print(f"Effective batch size: {effective_batch}")
# 3. Model
model = TransformerModel(config).cuda()
# Activation checkpointing
model.gradient_checkpointing_enable()
# DDP or FSDP
model = DDP(model, device_ids=[local_rank])
# 4. Data
dataset = PretrainingDataset(config)
sampler = DistributedSampler(dataset, shuffle=True)
dataloader = DataLoader(
dataset,
batch_size=config['batch_size'],
sampler=sampler,
num_workers=4,
pin_memory=True,
)
# 5. Optimizer & Scheduler
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config['learning_rate'],
weight_decay=config['weight_decay'],
betas=(0.9, 0.95),
)
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=config['warmup_steps'],
num_training_steps=config['total_steps'],
)
# 6. Utilities
    # GradScaler matters for fp16; with bf16 autocast it is redundant but harmless
    scaler = torch.cuda.amp.GradScaler()
stabilizer = TrainingStabilizer()
checkpoint_mgr = CheckpointManager("checkpoints")
# Resume from checkpoint
checkpoint = checkpoint_mgr.load_latest()
start_step = 0
if checkpoint:
model.module.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if checkpoint['scheduler_state_dict']:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_step = checkpoint['step']
if is_main:
print(f"Resumed from step {start_step}")
# 7. Training loop
model.train()
global_step = start_step
accumulated_loss = 0.0
for epoch in range(100): # Sufficiently large number
sampler.set_epoch(epoch)
for batch_idx, batch in enumerate(dataloader):
batch = {k: v.cuda() for k, v in batch.items()}
# Forward (Mixed Precision)
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
outputs = model(**batch)
loss = outputs.loss / config['gradient_accumulation']
            # Backward
            scaler.scale(loss).backward()
            accumulated_loss += loss.item()
            # Gradient Accumulation
            if (batch_idx + 1) % config['gradient_accumulation'] == 0:
                # Gradient clipping
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    config['max_grad_norm']
                )
                # Step
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                # accumulated_loss now equals the mean loss of this optimizer step
                step_loss = accumulated_loss
                accumulated_loss = 0.0  # reset on every rank, every step
                # Logging
                if is_main and global_step % 10 == 0:
                    lr = scheduler.get_last_lr()[0]
                    wandb.log({
                        'loss': step_loss,
                        'learning_rate': lr,
                        'grad_norm': grad_norm.item(),
                        'step': global_step,
                    })
                    print(f"Step {global_step}: loss={step_loss:.4f}, lr={lr:.2e}")
                # Checkpoint
                if checkpoint_mgr.should_save(global_step):
                    if is_main:
                        checkpoint_mgr.save(
                            model.module, optimizer, scheduler,
                            global_step, step_loss
                        )
# Termination condition
if global_step >= config['total_steps']:
break
if global_step >= config['total_steps']:
break
# Cleanup
dist.destroy_process_group()
if is_main:
wandb.finish()
if __name__ == "__main__":
main()
# Execute:
# torchrun --nproc_per_node=8 --nnodes=4 --node_rank=0 \
# --master_addr="master" --master_port=29500 train.py
References¶
Papers¶
- Rajbhandari et al. (2020). "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models"
- Narayanan et al. (2021). "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM"