06. Pre-training Infrastructure
Overview
Training a large foundation model runs on thousands of GPUs for weeks to months. This lesson covers distributed training strategies, memory optimization, and training stability techniques.
1. Distributed Training Paradigms
1.1 Overview of Parallelization Strategies
┌──────────────────────────────────────────────────────────────────┐
│                  Distributed Training Paradigms                  │
├──────────────────────────────────────────────────────────────────┤
│                                                                  │
│  Data Parallelism (DP)            Tensor Parallelism (TP)        │
│  ┌─────┐ ┌─────┐                  ┌──────────────────┐           │
│  │GPU 0│ │GPU 1│                  │  W = [W1 | W2]   │           │
│  │Model│ │Model│                  │ GPU0        GPU1 │           │
│  │Data1│ │Data2│                  │  W1          W2  │           │
│  └─────┘ └─────┘                  └──────────────────┘           │
│  Same model, different data       Each layer split across GPUs   │
│                                                                  │
│  Pipeline Parallelism (PP)        Sequence Parallelism (SP)      │
│  ┌─────┐ ┌─────┐                  ┌────┬────┬────┬────┐          │
│  │GPU 0│ │GPU 1│                  │ S1 │ S2 │ S3 │ S4 │          │
│  │L1-L6│→│L7-12│                  │GPU0│GPU1│GPU2│GPU3│          │
│  └─────┘ └─────┘                  └────┴────┴────┴────┘          │
│  Layers split sequentially        Sequence split across GPUs     │
│                                                                  │
│  3D Parallelism: DP + TP + PP combined                           │
│                                                                  │
└──────────────────────────────────────────────────────────────────┘
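The column split in the TP panel can be verified on a single device. Below is a minimal sketch in plain PyTorch that simulates two tensor-parallel workers on CPU (all dimensions are arbitrary illustration values):

import torch

# Simulate tensor parallelism: split W column-wise across two "workers".
torch.manual_seed(0)
x = torch.randn(4, 8)            # batch of 4, hidden dim 8
W = torch.randn(8, 6)            # full weight matrix
W1, W2 = W[:, :3], W[:, 3:]      # W = [W1 | W2], one shard per worker

# Each worker computes a partial output from its shard alone ...
y1, y2 = x @ W1, x @ W2
# ... and an all-gather (here: a concat) reconstructs the full output.
y_tp = torch.cat([y1, y2], dim=-1)
assert torch.allclose(y_tp, x @ W)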
1.2 Memory Analysis
def estimate_training_memory(
    num_params: int,         # number of parameters
    batch_size: int,
    seq_len: int,
    hidden_dim: int,
    num_layers: int,
    dtype_bytes: int = 2,    # fp16/bf16 = 2, fp32 = 4
    optimizer: str = 'adam'
) -> dict:
    """
    Estimate GPU memory required for training.

    Memory breakdown:
    1. Model parameters
    2. Gradients
    3. Optimizer states
    4. Activations (forward pass)
    """
    # 1. Model parameters
    param_memory = num_params * dtype_bytes
    # 2. Gradients (same size as the parameters)
    grad_memory = num_params * dtype_bytes
    # 3. Optimizer states
    if optimizer == 'adam':
        # Adam: momentum (fp32) + variance (fp32)
        optimizer_memory = num_params * 4 * 2  # 8 bytes per param
    elif optimizer == 'sgd':
        optimizer_memory = num_params * 4  # momentum only
    else:
        optimizer_memory = 0
    # 4. Activations (rough approximation)
    # Per layer: attention + FFN activations
    bytes_per_token = hidden_dim * dtype_bytes * 10  # heuristic factor
    activation_memory = batch_size * seq_len * bytes_per_token * num_layers
    # Activation checkpointing reduces this to roughly 1/sqrt(L)
    total = param_memory + grad_memory + optimizer_memory + activation_memory
    return {
        'parameters_gb': param_memory / 1e9,
        'gradients_gb': grad_memory / 1e9,
        'optimizer_gb': optimizer_memory / 1e9,
        'activations_gb': activation_memory / 1e9,
        'total_gb': total / 1e9,
    }
# Example: 7B model
memory = estimate_training_memory(
    num_params=7e9,
    batch_size=4,
    seq_len=2048,
    hidden_dim=4096,
    num_layers=32
)
print("Estimated memory for a 7B model:")
for key, value in memory.items():
    print(f"  {key}: {value:.1f} GB")
# Output:
#   parameters_gb: 14.0 GB
#   gradients_gb: 14.0 GB
#   optimizer_gb: 56.0 GB
#   activations_gb: ~21.5 GB (batch_size=4)
#   total_gb: ~105.5 GB
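No single device holds ~105 GB, which is what motivates the sharding techniques in the rest of this lesson. As a rough sketch of their effect: the parameter, gradient, and optimizer terms divide across GPUs while activations stay local. The 8-GPU count below is an assumption; memory is the dict computed above.

# Rough per-GPU memory under full sharding (FSDP / ZeRO Stage 3) on 8 GPUs.
n_gpus = 8
sharded_states = (memory['parameters_gb'] + memory['gradients_gb']
                  + memory['optimizer_gb']) / n_gpus
per_gpu_gb = sharded_states + memory['activations_gb']
print(f"~{per_gpu_gb:.1f} GB per GPU")  # ≈ 10.5 + 21.5 ≈ 32 GB for the 7B example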
2. FSDP (Fully Sharded Data Parallel)
2.1 FSDP Concept
┌──────────────────────────────────────────────────────────┐
│                      How FSDP Works                      │
├──────────────────────────────────────────────────────────┤
│                                                          │
│  Conventional DDP:                                       │
│    GPU 0: [Full Model] + [Data 0]                        │
│    GPU 1: [Full Model] + [Data 1]                        │
│    → full model replicated on every GPU (inefficient)    │
│                                                          │
│  FSDP (ZeRO Stage 3):                                    │
│    GPU 0: [Shard 0] + [Data 0]                           │
│    GPU 1: [Shard 1] + [Data 1]                           │
│                                                          │
│    Forward:  All-Gather assembles the full parameters    │
│    Backward: Reduce-Scatter shards the gradients         │
│                                                          │
│    Memory: (Params + Grads + Optim) / N + Activations    │
│                                                          │
└──────────────────────────────────────────────────────────┘
2.2 PyTorch FSDP Implementation
import os
import functools

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy,  # alternative: wrap by parameter count
)

def setup_fsdp_training():
    """Set up FSDP training.
    MyTransformerModel, config, and TransformerBlock are placeholders
    for your own model code."""
    # Initialize the process group (torchrun sets LOCAL_RANK)
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    # Create the model
    model = MyTransformerModel(config)
    # Mixed-precision settings
    mixed_precision = MixedPrecision(
        param_dtype=torch.bfloat16,    # parameters
        reduce_dtype=torch.bfloat16,   # gradient reduction
        buffer_dtype=torch.bfloat16,   # buffers
    )
    # Auto-wrap policy: shard at the granularity of Transformer layers
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock},
    )
    # Wrap with FSDP
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 equivalent
        mixed_precision=mixed_precision,
        auto_wrap_policy=wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        cpu_offload=CPUOffload(offload_params=False),
        device_id=local_rank,
    )
    return model
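This setup expects torchrun as the launcher, since torchrun populates LOCAL_RANK. A launch sketch for a single 8-GPU node (the script name is hypothetical):

# torchrun --nproc_per_node=8 train_fsdp.py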
def train_step_fsdp(model, batch, optimizer):
    """One FSDP training step (bf16 autocast, so no GradScaler is needed)."""
    model.train()
    # Forward
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        outputs = model(**batch)
        loss = outputs.loss
    # Backward
    loss.backward()
    # Gradient clipping: use FSDP's own method so the global norm is
    # computed correctly across shards
    model.clip_grad_norm_(max_norm=1.0)
    # Optimizer step
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
# 체ν¬ν¬μΈνΈ μ μ₯/λ‘λ
from torch.distributed.fsdp import (
FullStateDictConfig,
StateDictType,
)
def save_fsdp_checkpoint(model, optimizer, path):
"""FSDP 체ν¬ν¬μΈνΈ μ μ₯"""
# Full State Dict μ€μ
full_state_dict_config = FullStateDictConfig(
offload_to_cpu=True,
rank0_only=True,
)
with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
full_state_dict_config,
):
state_dict = model.state_dict()
optim_state = FSDP.optim_state_dict(model, optimizer)
if dist.get_rank() == 0:
torch.save({
'model': state_dict,
'optimizer': optim_state,
}, path)
dist.barrier()
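The matching load path is sketched below under the same full-state-dict convention: every rank loads the checkpoint and FSDP re-shards it on assignment. Restoring optimizer state is omitted here because the relevant helper (FSDP.optim_state_dict_to_load) has changed signatures across PyTorch releases; check the documentation for your version.

def load_fsdp_checkpoint(model, path):
    """Load a full-state-dict checkpoint into an FSDP-wrapped model (sketch)."""
    checkpoint = torch.load(path, map_location="cpu")
    with FSDP.state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    ):
        # FSDP re-shards the full state dict across ranks on load
        model.load_state_dict(checkpoint['model'])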
3. DeepSpeed ZeRO
3.1 ZeRO Stage Comparison
┌──────────────────────────────────────────────────────────────┐
│                    DeepSpeed ZeRO Stages                     │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  Stage 1: Optimizer State Partitioning                       │
│    - Only the optimizer states (Adam m, v) are partitioned   │
│    - Memory savings: ~4x                                     │
│                                                              │
│  Stage 2: + Gradient Partitioning                            │
│    - Gradients are partitioned as well                       │
│    - Memory savings: ~8x                                     │
│                                                              │
│  Stage 3: + Parameter Partitioning                           │
│    - Parameters are partitioned too (similar to FSDP)        │
│    - Memory savings: ~N (proportional to GPU count)          │
│                                                              │
│  ZeRO-Offload:  offload states to CPU/NVMe                   │
│  ZeRO-Infinity: scale to models far beyond GPU memory        │
│                                                              │
└──────────────────────────────────────────────────────────────┘
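The savings factors above follow the ZeRO paper's accounting for mixed-precision Adam: 2 bytes of fp16 parameters, 2 bytes of fp16 gradients, and K = 12 bytes of optimizer state per parameter (fp32 master weights, momentum, variance; slightly more conservative than the 8-byte Adam figure used in section 1.2). A small sketch of per-GPU model-state memory by stage, activations excluded (7B parameters on 64 GPUs are assumed values):

def zero_model_state_per_gpu(num_params: float, n_gpus: int, stage: int) -> float:
    """Per-GPU model-state memory (GB) for mixed-precision Adam,
    following the ZeRO paper's accounting."""
    p, g, opt = 2 * num_params, 2 * num_params, 12 * num_params
    if stage == 0:            # plain data parallelism: everything replicated
        total = p + g + opt
    elif stage == 1:          # optimizer states sharded
        total = p + g + opt / n_gpus
    elif stage == 2:          # + gradients sharded
        total = p + g / n_gpus + opt / n_gpus
    else:                     # stage 3: + parameters sharded
        total = (p + g + opt) / n_gpus
    return total / 1e9

for stage in range(4):
    print(f"Stage {stage}: {zero_model_state_per_gpu(7e9, 64, stage):.1f} GB/GPU")
# Stage 0: 112.0, Stage 1: 29.3, Stage 2: 15.5, Stage 3: 1.8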
3.2 DeepSpeed Configuration
# ds_config.json (expressed here as a Python dict)
ds_config = {
    "train_batch_size": 256,
    "gradient_accumulation_steps": 8,
    "train_micro_batch_size_per_gpu": 4,
    # FP16 settings
    "fp16": {
        "enabled": True,
        "loss_scale": 0,  # 0 = dynamic loss scaling
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    # BF16 settings (alternative to fp16)
    "bf16": {
        "enabled": False
    },
    # ZeRO Stage 3
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",  # or "nvme"
            "pin_memory": True
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": True
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
        # Note: "auto" values are resolved by the HuggingFace Transformers
        # integration; with raw deepspeed.initialize, supply concrete numbers.
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    # Activation (gradient) checkpointing
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": True,
        "contiguous_memory_optimization": True,
        "number_checkpoints": None,
        "synchronize_checkpoint_boundary": False,
        "profile": False
    },
    # Optimizer
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 3e-4,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 0.01
        }
    },
    # Scheduler
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 3e-4,
            "warmup_num_steps": 1000,
            "total_num_steps": 100000
        }
    }
}
3.3 DeepSpeed Training Code
import deepspeed
import torch

def train_with_deepspeed():
    """DeepSpeed training loop.
    MyTransformerModel, config, create_dataloader, and num_epochs are
    placeholders defined elsewhere."""
    # Model and data
    model = MyTransformerModel(config)
    train_dataloader = create_dataloader(...)
    # DeepSpeed initialization
    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
    )
    # Training loop
    for epoch in range(num_epochs):
        for step, batch in enumerate(train_dataloader):
            batch = {k: v.to(model_engine.device) for k, v in batch.items()}
            # Forward
            outputs = model_engine(**batch)
            loss = outputs.loss
            # Backward (DeepSpeed handles gradient scaling/accumulation)
            model_engine.backward(loss)
            # Step
            model_engine.step()
            if step % 100 == 0:
                print(f"Step {step}, Loss: {loss.item():.4f}")
        # Save a checkpoint
        model_engine.save_checkpoint("checkpoint_dir")

# Launch:
# deepspeed --num_gpus=8 train.py --deepspeed_config ds_config.json
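Resuming from a DeepSpeed checkpoint mirrors save_checkpoint; a minimal sketch (load_checkpoint restores module, optimizer, and LR-scheduler state, and returns None for the path when nothing is found):

load_path, client_state = model_engine.load_checkpoint("checkpoint_dir")
if load_path is None:
    print("No checkpoint found; starting from scratch")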
4. Activation Checkpointing (Gradient Checkpointing)
4.1 Concept
Standard forward:
  Layer 1 → [store Act1] → Layer 2 → [store Act2] → ... → Loss
  Backward uses the stored Act1, Act2, ... to compute gradients
  → Memory: O(L), proportional to the number of layers

Activation checkpointing:
  Layer 1 → [checkpoint] → Layer 2 → Layer 3 → [checkpoint] → ... → Loss
  Backward recomputes intermediate activations from the nearest checkpoint
  → Memory: O(√L), the square root of the layer count
  → Compute: ~33% more (the cost of recomputation)
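PyTorch exposes this directly in torch.utils.checkpoint. A minimal runnable sketch using checkpoint_sequential on a toy model (layer sizes are arbitrary; use_reentrant=False follows the PyTorch 2.x recommendation):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# 16 layers split into 4 segments: only segment-boundary activations are
# kept; everything inside a segment is recomputed during backward.
model = nn.Sequential(*[nn.Linear(512, 512) for _ in range(16)])
x = torch.randn(8, 512, requires_grad=True)
out = checkpoint_sequential(model, segments=4, input=x, use_reentrant=False)
out.sum().backward()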
4.2 Implementation
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

class TransformerBlockWithCheckpoint(nn.Module):
    """Transformer block with optional activation checkpointing.
    MultiHeadAttention and FeedForward are placeholders for your own modules."""
    def __init__(self, config, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.attention = MultiHeadAttention(config)
        self.ffn = FeedForward(config)
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)

    def forward(self, x, attention_mask=None):
        if self.use_checkpoint and self.training:
            # Recompute this block's activations during backward
            return checkpoint(
                self._forward_impl,
                x, attention_mask,
                use_reentrant=False,  # recommended on PyTorch 2.0+
            )
        else:
            return self._forward_impl(x, attention_mask)

    def _forward_impl(self, x, attention_mask):
        # Attention (pre-norm + residual)
        residual = x
        x = self.norm1(x)
        x = self.attention(x, attention_mask)
        x = residual + x
        # FFN (pre-norm + residual)
        residual = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = residual + x
        return x

class TransformerWithSelectiveCheckpoint(nn.Module):
    """Selective checkpointing: only a fraction of the layers recompute."""
    def __init__(self, config, checkpoint_ratio=0.5):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlockWithCheckpoint(
                config,
                # Checkpoint only every (1/ratio)-th layer
                use_checkpoint=(i % int(1 / checkpoint_ratio) == 0)
            )
            for i in range(config.num_layers)
        ])

    def forward(self, x, attention_mask=None):
        for layer in self.layers:
            x = layer(x, attention_mask)
        return x
5. Training Stability
5.1 Handling Loss Spikes
import torch
import torch.nn as nn

class TrainingStabilizer:
    """Tracks loss and gradient-norm history to detect unstable steps."""
    def __init__(
        self,
        loss_spike_threshold: float = 5.0,   # 5x the recent average
        grad_norm_threshold: float = 10.0,
        window_size: int = 100
    ):
        self.loss_spike_threshold = loss_spike_threshold
        self.grad_norm_threshold = grad_norm_threshold
        self.window_size = window_size
        self.loss_history = []
        self.grad_norm_history = []
        self.skipped_steps = 0

    def check_loss_spike(self, loss: float) -> bool:
        """Detect a loss spike relative to the moving average."""
        if len(self.loss_history) < self.window_size:
            self.loss_history.append(loss)
            return False
        avg_loss = sum(self.loss_history[-self.window_size:]) / self.window_size
        if loss > avg_loss * self.loss_spike_threshold:
            print(f"⚠️ Loss spike detected: {loss:.4f} (avg: {avg_loss:.4f})")
            return True  # do not record the spiked value
        self.loss_history.append(loss)
        return False

    def check_grad_norm(self, model: nn.Module) -> tuple[float, bool]:
        """Compute the global gradient norm and flag spikes."""
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        is_spike = total_norm > self.grad_norm_threshold
        if is_spike:
            print(f"⚠️ Gradient spike: {total_norm:.4f}")
        self.grad_norm_history.append(total_norm)
        return total_norm, is_spike

    def should_skip_step(self, loss: float, model: nn.Module) -> bool:
        """Decide whether to skip this step. Call after backward(),
        so that gradients exist for the norm check."""
        loss_spike = self.check_loss_spike(loss)
        _, grad_spike = self.check_grad_norm(model)
        if loss_spike or grad_spike:
            self.skipped_steps += 1
            return True
        return False
def stable_training_step(
    model, batch, optimizer, stabilizer, scaler=None
):
    """Training step that skips loss/gradient spikes.
    Backward runs first so that should_skip_step can inspect real gradients."""
    # Forward
    with torch.cuda.amp.autocast():
        outputs = model(**batch)
        loss = outputs.loss
    # Backward
    if scaler:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale so gradient norms are true values
    else:
        loss.backward()
    # Spike check (gradients exist now)
    if stabilizer.should_skip_step(loss.item(), model):
        optimizer.zero_grad()
        if scaler:
            scaler.update()  # keep the scaler's state consistent
        print(f"Skipping step (total skipped: {stabilizer.skipped_steps})")
        return None
    # Gradient clipping and step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    if scaler:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    optimizer.zero_grad()
    return loss.item()
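A sketch of how this slots into a loop (model, optimizer, and dataloader are assumed to exist):

stabilizer = TrainingStabilizer()
for batch in dataloader:
    loss = stable_training_step(model, batch, optimizer, stabilizer)
    if loss is None:
        continue  # the spiked step was skipped; move on to the next batch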
5.2 Checkpointing Strategy
import os
from datetime import datetime

import torch

class CheckpointManager:
    """Manages periodic checkpoints with rotation."""
    def __init__(
        self,
        save_dir: str,
        max_checkpoints: int = 5,
        save_interval_steps: int = 1000,
        save_interval_hours: float = 1.0
    ):
        self.save_dir = save_dir
        self.max_checkpoints = max_checkpoints
        self.save_interval_steps = save_interval_steps
        self.save_interval_hours = save_interval_hours
        self.last_save_time = datetime.now()
        self.checkpoints = []
        os.makedirs(save_dir, exist_ok=True)

    def should_save(self, step: int) -> bool:
        """Decide whether to save a checkpoint now."""
        # Step-based
        if step > 0 and step % self.save_interval_steps == 0:
            return True
        # Time-based
        elapsed = (datetime.now() - self.last_save_time).total_seconds() / 3600
        if elapsed >= self.save_interval_hours:
            return True
        return False

    def save(
        self,
        model,
        optimizer,
        scheduler,
        step: int,
        loss: float,
        **extra
    ):
        """Save a checkpoint."""
        checkpoint_name = f"checkpoint-{step}"
        checkpoint_path = os.path.join(self.save_dir, checkpoint_name)
        state = {
            'step': step,
            'loss': loss,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            **extra
        }
        torch.save(state, checkpoint_path + ".pt")
        # Record metadata
        self.checkpoints.append({
            'path': checkpoint_path,
            'step': step,
            'loss': loss,
            'time': datetime.now().isoformat()
        })
        self.last_save_time = datetime.now()
        # Remove stale checkpoints
        self._cleanup()
        print(f"💾 Saved checkpoint: {checkpoint_name}")

    def _cleanup(self):
        """Delete the oldest checkpoints beyond max_checkpoints."""
        while len(self.checkpoints) > self.max_checkpoints:
            oldest = self.checkpoints.pop(0)
            if os.path.exists(oldest['path'] + ".pt"):
                os.remove(oldest['path'] + ".pt")
                print(f"🗑️ Removed old checkpoint: {oldest['path']}")

    def load_latest(self) -> dict:
        """Load the most recent checkpoint, or None if there is none."""
        if not self.checkpoints:
            # Fall back to scanning the directory; sort numerically by step
            # (lexicographic order would rank checkpoint-999 after -1000)
            files = sorted(
                [f for f in os.listdir(self.save_dir)
                 if f.startswith("checkpoint-") and f.endswith(".pt")],
                key=lambda f: int(f.split("-")[1].split(".")[0])
            )
            if not files:
                return None
            latest = files[-1]
            return torch.load(os.path.join(self.save_dir, latest))
        return torch.load(self.checkpoints[-1]['path'] + ".pt")
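A usage sketch (model, optimizer, scheduler, and the train_step helper are hypothetical placeholders):

mgr = CheckpointManager("checkpoints", max_checkpoints=3, save_interval_steps=500)
ckpt = mgr.load_latest()
start_step = ckpt['step'] if ckpt else 0
for step in range(start_step + 1, 10_001):
    loss = train_step(model, optimizer)  # hypothetical helper
    if mgr.should_save(step):
        mgr.save(model, optimizer, scheduler, step, loss)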
6. Learning Rate Scheduling
6.1 Warmup + Cosine Decay
import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.1,
    num_cycles: float = 0.5
):
    """
    Warmup + cosine-decay scheduler.

    Early training: linear warmup (0 → max_lr)
    Afterwards:     cosine decay (max_lr → min_lr)
    """
    def lr_lambda(current_step):
        # Warmup
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
        # Decay only down to min_lr
        decayed = (1 - min_lr_ratio) * cosine_decay + min_lr_ratio
        return decayed
    return LambdaLR(optimizer, lr_lambda)
# WSD (Warmup-Stable-Decay) scheduler
def get_wsd_schedule(
    optimizer,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_decay_steps: int,
    min_lr_ratio: float = 0.1
):
    """
    Warmup-Stable-Decay scheduler.

    1. Warmup: 0 → max_lr
    2. Stable: hold at max_lr
    3. Decay:  max_lr → min_lr (cosine)
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # Warmup phase
            return float(current_step) / float(max(1, num_warmup_steps))
        elif current_step < num_warmup_steps + num_stable_steps:
            # Stable phase
            return 1.0
        else:
            # Decay phase
            decay_step = current_step - num_warmup_steps - num_stable_steps
            progress = float(decay_step) / float(max(1, num_decay_steps))
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            return (1 - min_lr_ratio) * cosine_decay + min_lr_ratio
    return LambdaLR(optimizer, lr_lambda)
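A quick sanity check of the WSD shape, printing the learning rate at a few steps (the one-parameter optimizer is a throwaway for illustration):

import torch

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=3e-4)
sched = get_wsd_schedule(opt, num_warmup_steps=100,
                         num_stable_steps=800, num_decay_steps=100)
for step in range(1000):
    opt.step()
    sched.step()
    if step in (49, 99, 500, 999):
        print(f"step {step:4d}: lr = {sched.get_last_lr()[0]:.2e}")
# Expect ~1.5e-04 mid-warmup, 3.0e-04 through the stable phase,
# and 3.0e-05 (min_lr_ratio * max_lr) at the very end.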
7. Hands-on: A Complete Training Script
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import wandb
def main():
    """Complete distributed training script.
    TransformerModel and PretrainingDataset are placeholders for your own code."""
    # 1. Initialize the process group
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = dist.get_world_size()
    torch.cuda.set_device(local_rank)
    # Only rank 0 logs
    is_main = local_rank == 0
    if is_main:
        wandb.init(project="foundation-model-training")
    # 2. Configuration
    config = {
        'hidden_size': 4096,
        'num_layers': 32,
        'num_heads': 32,
        'vocab_size': 50257,
        'max_seq_len': 2048,
        'batch_size': 4,  # per GPU
        'gradient_accumulation': 8,
        'learning_rate': 3e-4,
        'warmup_steps': 2000,
        'total_steps': 100000,
        'weight_decay': 0.1,
        'max_grad_norm': 1.0,
    }
    effective_batch = config['batch_size'] * config['gradient_accumulation'] * world_size
    print(f"Effective batch size: {effective_batch}")
    # 3. Model
    model = TransformerModel(config).cuda()
    # Activation checkpointing
    model.gradient_checkpointing_enable()
    # DDP (or FSDP, as in section 2)
    model = DDP(model, device_ids=[local_rank])
    # 4. Data
    dataset = PretrainingDataset(config)
    sampler = DistributedSampler(dataset, shuffle=True)
    dataloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
    )
    # 5. Optimizer & scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.95),
    )
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config['warmup_steps'],
        num_training_steps=config['total_steps'],
    )
    # 6. Utilities
    # Note: GradScaler matters for fp16; with the bf16 autocast below it is
    # effectively redundant but harmless.
    scaler = torch.cuda.amp.GradScaler()
    stabilizer = TrainingStabilizer()
    checkpoint_mgr = CheckpointManager("checkpoints")
    # Restore from a checkpoint if one exists
    checkpoint = checkpoint_mgr.load_latest()
    start_step = 0
    if checkpoint:
        model.module.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if checkpoint['scheduler_state_dict']:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_step = checkpoint['step']
        if is_main:
            print(f"Resumed from step {start_step}")
    # 7. Training loop
    model.train()
    global_step = start_step
    accumulated_loss = 0.0
    for epoch in range(100):  # arbitrarily large; we stop at total_steps
        sampler.set_epoch(epoch)
        for batch_idx, batch in enumerate(dataloader):
            batch = {k: v.cuda() for k, v in batch.items()}
            # Forward (mixed precision)
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                outputs = model(**batch)
                loss = outputs.loss / config['gradient_accumulation']
            # Backward
            scaler.scale(loss).backward()
            accumulated_loss += loss.item()
            # Gradient accumulation boundary
            if (batch_idx + 1) % config['gradient_accumulation'] == 0:
                # Gradient clipping
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    config['max_grad_norm']
                )
                # Step
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                step_loss = accumulated_loss  # loss for this optimizer step
                accumulated_loss = 0.0
                # Logging
                if is_main and global_step % 10 == 0:
                    lr = scheduler.get_last_lr()[0]
                    wandb.log({
                        'loss': step_loss,
                        'learning_rate': lr,
                        'grad_norm': grad_norm.item(),
                        'step': global_step,
                    })
                    print(f"Step {global_step}: loss={step_loss:.4f}, lr={lr:.2e}")
                # Checkpointing
                if checkpoint_mgr.should_save(global_step):
                    if is_main:
                        checkpoint_mgr.save(
                            model.module, optimizer, scheduler,
                            global_step, step_loss
                        )
                # Termination condition
                if global_step >= config['total_steps']:
                    break
        if global_step >= config['total_steps']:
            break
    # Cleanup
    dist.destroy_process_group()
    if is_main:
        wandb.finish()
if __name__ == "__main__":
    main()

# Launch (4 nodes x 8 GPUs; run on each node with its own node_rank):
# torchrun --nproc_per_node=8 --nnodes=4 --node_rank=0 \
#     --master_addr="master" --master_port=29500 train.py
References
Papers
- Rajbhandari et al. (2020). "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models"
- Narayanan et al. (2021). "Efficient Large-Scale Language Model Training on GPU Clusters"