Weights & Biases (W&B)
Weights & Biases (W&B)¶
1. W&B Overview¶
Weights & Biases (W&B) is a platform for ML experiment tracking, hyperparameter tuning, dataset/model versioning, and model management.
1.1 Core Features¶
┌─────────────────────────────────────────────────────────────────────┐
│ Weights & Biases Features │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Experiments │ │ Sweeps │ │ Artifacts │ │
│ │ │ │ │ │ │ │
│ │ - Experiment│ │ - Hyper │ │ - Datasets │ │
│ │ tracking │ │ parameter │ │ - Models │ │
│ │ - Metrics │ │ tuning │ │ - Version │ │
│ │ - Visualization│ │ │ │ control │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Tables │ │ Reports │ │ Models │ │
│ │ │ │ │ │ │ │
│ │ - Data │ │ - Documentation│ │ - Model │ │
│ │ visualization│ │ - Sharing │ │ registry │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘
1.2 Installation and Setup¶
# Install the wandb client library
pip install wandb
# Login interactively (prompts for the API key)
wandb login
# Enter API key (https://wandb.ai/authorize)
# Set as environment variable (useful for CI / non-interactive environments)
export WANDB_API_KEY=your-api-key
# Login from Python (avoid committing real keys to source control;
# prefer the WANDB_API_KEY environment variable)
import wandb
wandb.login(key="your-api-key")
2. Basic Experiment Tracking¶
2.1 First Experiment¶
"""
Basic W&B Usage
"""
import wandb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
# Prepare data
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42
)
# Initialize W&B
wandb.init(
project="iris-classification", # Project name
name="random-forest-baseline", # Run name
config={ # Hyperparameters
"n_estimators": 100,
"max_depth": 5,
"random_state": 42
},
tags=["baseline", "random-forest"],
notes="Initial baseline experiment"
)
# Access config
config = wandb.config
# Train model
model = RandomForestClassifier(
n_estimators=config.n_estimators,
max_depth=config.max_depth,
random_state=config.random_state
)
model.fit(X_train, y_train)
# Predict and evaluate
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
# Log metrics
wandb.log({
"accuracy": accuracy,
"test_size": len(X_test),
"train_size": len(X_train)
})
# Finish run
wandb.finish()
2.2 Logging Training Process¶
"""
Real-time Training Process Logging
"""
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# Initialize
wandb.init(project="pytorch-training")
# Define model
model = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10)
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=wandb.config.get("lr", 0.001))
# Track model graph in W&B
wandb.watch(model, criterion, log="all", log_freq=100)
# Training loop
for epoch in range(wandb.config.get("epochs", 10)):
model.train()
train_loss = 0
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
train_loss += loss.item()
# Batch-level logging (optional)
if batch_idx % 100 == 0:
wandb.log({
"batch_loss": loss.item(),
"epoch": epoch,
"batch": batch_idx
})
# Epoch-level logging
avg_loss = train_loss / len(train_loader)
val_accuracy = evaluate(model, val_loader)
wandb.log({
"epoch": epoch,
"train_loss": avg_loss,
"val_accuracy": val_accuracy
})
# Save checkpoint
if val_accuracy > best_accuracy:
torch.save(model.state_dict(), "best_model.pth")
wandb.save("best_model.pth")
best_accuracy = val_accuracy
wandb.finish()
2.3 Logging Various Data Types¶
"""
Logging Various Data Types
"""
import wandb
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
wandb.init(project="data-logging-demo")
# 1. Log images
images = wandb.Image(
np.random.rand(100, 100, 3),
caption="Random Image"
)
wandb.log({"random_image": images})
# PIL image
pil_image = Image.open("sample.png")
wandb.log({"pil_image": wandb.Image(pil_image)})
# Multiple images
wandb.log({
"examples": [wandb.Image(img, caption=f"Sample {i}")
for i, img in enumerate(image_batch[:10])]
})
# 2. Log plots
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9])
ax.set_title("Training Curve")
wandb.log({"plot": wandb.Image(fig)})
plt.close()
# Or use plotly
import plotly.express as px
fig = px.scatter(x=[1, 2, 3], y=[1, 4, 9])
wandb.log({"plotly_chart": fig})
# 3. Histogram
wandb.log({"predictions": wandb.Histogram(predictions)})
# 4. Table
columns = ["id", "image", "prediction", "label"]
data = [
[i, wandb.Image(img), pred, label]
for i, (img, pred, label) in enumerate(zip(images, preds, labels))
]
table = wandb.Table(columns=columns, data=data)
wandb.log({"predictions_table": table})
# 5. Confusion Matrix
wandb.log({
"confusion_matrix": wandb.plot.confusion_matrix(
y_true=y_true,
preds=y_pred,
class_names=class_names
)
})
# 6. ROC Curve
wandb.log({
"roc_curve": wandb.plot.roc_curve(
y_true, y_scores, labels=class_names
)
})
# 7. PR Curve
wandb.log({
"pr_curve": wandb.plot.pr_curve(
y_true, y_scores, labels=class_names
)
})
wandb.finish()
3. Sweeps (Hyperparameter Tuning)¶
3.1 Sweep Configuration¶
"""
W&B Sweeps Configuration
"""
import wandb
# Sweep configuration
sweep_config = {
"name": "hyperparam-sweep",
"method": "bayes", # random, grid, bayes
"metric": {
"name": "val_accuracy",
"goal": "maximize"
},
"parameters": {
"learning_rate": {
"distribution": "log_uniform_values",
"min": 1e-5,
"max": 1e-1
},
"batch_size": {
"values": [16, 32, 64, 128]
},
"epochs": {
"value": 50 # Fixed value
},
"optimizer": {
"values": ["adam", "sgd", "rmsprop"]
},
"hidden_dim": {
"distribution": "int_uniform",
"min": 32,
"max": 256
},
"dropout": {
"distribution": "uniform",
"min": 0.0,
"max": 0.5
}
},
"early_terminate": {
"type": "hyperband",
"min_iter": 5,
"eta": 3
}
}
# Create sweep
sweep_id = wandb.sweep(sweep_config, project="my-project")
print(f"Sweep ID: {sweep_id}")
3.2 Running Sweep Agent¶
"""
Sweep Training Function
"""
import wandb
import torch
def train_sweep():
"""Training function to run in sweep"""
# Initialize W&B (sweep provides config)
wandb.init()
config = wandb.config
# Create model
model = create_model(
hidden_dim=config.hidden_dim,
dropout=config.dropout
)
# Setup optimizer
if config.optimizer == "adam":
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
elif config.optimizer == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
else:
optimizer = torch.optim.RMSprop(model.parameters(), lr=config.learning_rate)
# Data loaders
train_loader = DataLoader(train_dataset, batch_size=config.batch_size)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
# Training
for epoch in range(config.epochs):
train_loss = train_one_epoch(model, train_loader, optimizer)
val_accuracy = evaluate(model, val_loader)
wandb.log({
"train_loss": train_loss,
"val_accuracy": val_accuracy,
"epoch": epoch
})
wandb.finish()
# Run sweep
wandb.agent(
sweep_id,
function=train_sweep,
count=50 # Maximum number of runs
)
3.3 Running Sweep from CLI¶
# Create sweep.yaml file (see the example below)
# Start sweep (prints the sweep id to use with the agent)
wandb sweep sweep.yaml
# Run agent (can run in parallel on multiple machines)
wandb agent username/project/sweep_id
# sweep.yaml
# Bayesian-optimization sweep maximizing val_accuracy.
# FIX: restored the nested indentation this YAML requires — the flattened
# form was not valid YAML.
name: hyperparameter-sweep
method: bayes
metric:
  name: val_accuracy
  goal: maximize
parameters:
  learning_rate:
    # log-uniform sampling over [1e-5, 1e-1]
    distribution: log_uniform_values
    min: 0.00001
    max: 0.1
  batch_size:
    values: [16, 32, 64]
  hidden_dim:
    distribution: int_uniform
    min: 64
    max: 512
4. Artifacts¶
4.1 Dataset Version Control¶
"""
Managing Datasets with W&B Artifacts
"""
import wandb
# Create and upload artifact
wandb.init(project="dataset-versioning")
# Create dataset artifact
dataset_artifact = wandb.Artifact(
name="mnist-dataset",
type="dataset",
description="MNIST dataset for classification",
metadata={
"size": 70000,
"classes": 10,
"source": "torchvision"
}
)
# Add files/directories
dataset_artifact.add_file("data/train.csv")
dataset_artifact.add_dir("data/images/")
# Add remote reference (reference without download)
dataset_artifact.add_reference("s3://bucket/large_data/")
# Upload
wandb.log_artifact(dataset_artifact)
wandb.finish()
4.2 Model Artifacts¶
"""
Managing Model Artifacts
"""
import wandb
import torch
wandb.init(project="model-artifacts")
# After training...
# Create model artifact
model_artifact = wandb.Artifact(
name="churn-model",
type="model",
description="Customer churn prediction model",
metadata={
"accuracy": 0.95,
"framework": "pytorch",
"architecture": "MLP"
}
)
# Save and add model file
torch.save(model.state_dict(), "model.pth")
model_artifact.add_file("model.pth")
# Also add config file
model_artifact.add_file("config.yaml")
# Upload
wandb.log_artifact(model_artifact)
# Link model to specific alias
wandb.run.link_artifact(model_artifact, "model-registry/churn-model", aliases=["latest", "production"])
wandb.finish()
4.3 Using Artifacts¶
"""
Downloading and Using Artifacts
"""
import wandb
wandb.init(project="using-artifacts")
# Download artifact
artifact = wandb.use_artifact("mnist-dataset:latest") # or :v0, :v1, etc.
artifact_dir = artifact.download()
print(f"Downloaded to: {artifact_dir}")
# Direct access to artifact files
with artifact.file("train.csv") as f:
df = pd.read_csv(f)
# Dependency tracking (this run uses this artifact)
# use_artifact() handles this automatically
wandb.finish()
4.4 Artifact Lineage¶
"""
Tracking Artifact Lineage
"""
import wandb
# Data → Training → Model lineage
wandb.init(project="lineage-demo")
# 1. Input artifact (dataset)
dataset = wandb.use_artifact("processed-data:latest")
# 2. Perform training
# ...
# 3. Output artifact (model)
model_artifact = wandb.Artifact("trained-model", type="model")
model_artifact.add_file("model.pth")
wandb.log_artifact(model_artifact)
# Can view entire lineage graph in W&B UI
# Dataset → (training run) → Model
wandb.finish()
5. Comparison with MLflow¶
5.1 Feature Comparison¶
"""
MLflow vs W&B Comparison
"""
comparison = {
"Experiment Tracking": {
"MLflow": "Open source, self-hosted",
"W&B": "SaaS-based, free tier available"
},
"Visualization": {
"MLflow": "Basic visualization",
"W&B": "Rich visualization, real-time updates"
},
"Collaboration": {
"MLflow": "Limited",
"W&B": "Team features, report sharing"
},
"Hyperparameter Tuning": {
"MLflow": "Requires external tools (Optuna, etc.)",
"W&B": "Built-in Sweeps"
},
"Model Registry": {
"MLflow": "Full functionality",
"W&B": "Model Registry (recently added)"
},
"Deployment": {
"MLflow": "MLflow Serving",
"W&B": "No direct support (integrate other tools)"
},
"Cost": {
"MLflow": "Free (infrastructure costs only)",
"W&B": "Free tier + paid plans"
}
}
5.2 Using Together¶
"""
Using MLflow and W&B Together
"""
import mlflow
import wandb
# Initialize both platforms
wandb.init(project="dual-tracking")
mlflow.set_experiment("dual-tracking")
with mlflow.start_run():
# Common configuration
params = {"lr": 0.001, "epochs": 100}
# Log parameters to both
mlflow.log_params(params)
wandb.config.update(params)
# Training loop
for epoch in range(params["epochs"]):
loss = train_one_epoch()
accuracy = evaluate()
# Log metrics to both
mlflow.log_metrics({"loss": loss, "accuracy": accuracy}, step=epoch)
wandb.log({"loss": loss, "accuracy": accuracy, "epoch": epoch})
# Save model
mlflow.sklearn.log_model(model, "model")
wandb.save("model.pkl")
wandb.finish()
6. Advanced Features¶
6.1 Team Collaboration¶
"""
Team Project Setup
"""
import wandb
# Log to team project
wandb.init(
entity="team-name", # Team name
project="shared-project", # Project name
group="experiment-group", # Experiment group (group related experiments)
job_type="training" # Job type
)
6.2 Report Creation¶
"""
W&B Reports API
"""
import wandb
# Reports are typically created in UI, but also available via API
api = wandb.Api()
# Query all runs in project
runs = api.runs("username/project")
for run in runs:
print(f"Run: {run.name}")
print(f" Config: {run.config}")
print(f" Summary: {run.summary}")
print(f" History: {run.history().shape}")
6.3 Alert Configuration¶
"""
W&B Alerts
"""
import wandb
wandb.init(project="alerting-demo")
# Trigger alerts during training
for epoch in range(100):
accuracy = train_and_evaluate()
if accuracy > 0.95:
wandb.alert(
title="High Accuracy Achieved!",
text=f"Model achieved {accuracy:.2%} accuracy at epoch {epoch}",
level=wandb.AlertLevel.INFO
)
if accuracy < 0.5:
wandb.alert(
title="Training Issue",
text=f"Accuracy dropped to {accuracy:.2%}",
level=wandb.AlertLevel.WARN
)
wandb.log({"accuracy": accuracy, "epoch": epoch})
wandb.finish()
Exercises¶
Exercise 1: Basic Experiment Tracking¶
Train a CNN model on the MNIST dataset and track experiments with W&B.
Exercise 2: Run Sweeps¶
Execute a Bayesian optimization sweep for 3 or more hyperparameters.
Exercise 3: Artifacts¶
Save datasets and models as artifacts and verify lineage.
Summary¶
| Feature | W&B | MLflow |
|---|---|---|
| Experiment Tracking | wandb.log() | mlflow.log_metrics() |
| Hyperparameter Tuning | Sweeps | External tools |
| Data/Model Versioning | Artifacts | Model Registry |
| Visualization | Rich dashboards | Basic UI |
| Collaboration | Teams, Reports | Limited |
| Hosting | SaaS / Self-hosted | Self-hosted |