Weights & Biases (W&B)
Weights & Biases (W&B)¶
1. W&B Overview¶
Weights & Biases is a platform for ML experiment tracking, hyperparameter tuning, and model management.
1.1 Key Features¶
+----------------------------------------------------------------------+
|                     Weights & Biases Features                        |
+----------------------------------------------------------------------+
|                                                                      |
|  +--------------+   +--------------+   +--------------+              |
|  | Experiments  |   |    Sweeps    |   |  Artifacts   |              |
|  |              |   |              |   |              |              |
|  | - run        |   | - hyper-     |   | - datasets   |              |
|  |   tracking   |   |   parameter  |   | - models     |              |
|  | - metrics    |   |   tuning     |   | - version    |              |
|  | - charts     |   |              |   |   control    |              |
|  +--------------+   +--------------+   +--------------+              |
|                                                                      |
|  +--------------+   +--------------+   +--------------+              |
|  |    Tables    |   |   Reports    |   |    Models    |              |
|  |              |   |              |   |              |              |
|  | - data       |   | - docs       |   | - model      |              |
|  |   exploration|   | - sharing    |   |   registry   |              |
|  +--------------+   +--------------+   +--------------+              |
|                                                                      |
+----------------------------------------------------------------------+
1.2 Installation and Setup¶
# Install the wandb client
pip install wandb
# Log in (prompts for your API key from https://wandb.ai/authorize)
wandb login
# Or provide the key via an environment variable
export WANDB_API_KEY=your-api-key
# Log in from Python
import wandb
wandb.login(key="your-api-key")
2. Basic Experiment Tracking¶
2.1 First Experiment¶
"""
W&B κΈ°λ³Έ μ¬μ©λ²
"""
import wandb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
# λ°μ΄ν° μ€λΉ
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42
)
# W&B μ΄κΈ°ν
wandb.init(
project="iris-classification", # νλ‘μ νΈ μ΄λ¦
name="random-forest-baseline", # μ€ν μ΄λ¦
config={ # νμ΄νΌνλΌλ―Έν°
"n_estimators": 100,
"max_depth": 5,
"random_state": 42
},
tags=["baseline", "random-forest"],
notes="Initial baseline experiment"
)
# config μ κ·Ό
config = wandb.config
# λͺ¨λΈ νμ΅
model = RandomForestClassifier(
n_estimators=config.n_estimators,
max_depth=config.max_depth,
random_state=config.random_state
)
model.fit(X_train, y_train)
# μμΈ‘ λ° νκ°
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
# λ©νΈλ¦ λ‘κΉ
wandb.log({
"accuracy": accuracy,
"test_size": len(X_test),
"train_size": len(X_train)
})
# μ€ν μ’
λ£
wandb.finish()
2.2 Logging the Training Process¶
"""
νμ΅ κ³Όμ μ€μκ° λ‘κΉ
"""
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# μ΄κΈ°ν
wandb.init(project="pytorch-training")
# λͺ¨λΈ μ μ
model = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10)
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=wandb.config.get("lr", 0.001))
# W&Bμμ λͺ¨λΈ κ·Έλν μΆμ
wandb.watch(model, criterion, log="all", log_freq=100)
# νμ΅ λ£¨ν
for epoch in range(wandb.config.get("epochs", 10)):
model.train()
train_loss = 0
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
train_loss += loss.item()
# λ°°μΉλ³ λ‘κΉ
(μ ν)
if batch_idx % 100 == 0:
wandb.log({
"batch_loss": loss.item(),
"epoch": epoch,
"batch": batch_idx
})
# μνλ³ λ‘κΉ
avg_loss = train_loss / len(train_loader)
val_accuracy = evaluate(model, val_loader)
wandb.log({
"epoch": epoch,
"train_loss": avg_loss,
"val_accuracy": val_accuracy
})
# 체ν¬ν¬μΈνΈ μ μ₯
if val_accuracy > best_accuracy:
torch.save(model.state_dict(), "best_model.pth")
wandb.save("best_model.pth")
best_accuracy = val_accuracy
wandb.finish()
2.3 Logging Various Data Types¶
"""
λ€μν λ°μ΄ν° νμ
λ‘κΉ
"""
import wandb
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
wandb.init(project="data-logging-demo")
# 1. μ΄λ―Έμ§ λ‘κΉ
images = wandb.Image(
np.random.rand(100, 100, 3),
caption="Random Image"
)
wandb.log({"random_image": images})
# PIL μ΄λ―Έμ§
pil_image = Image.open("sample.png")
wandb.log({"pil_image": wandb.Image(pil_image)})
# μ¬λ¬ μ΄λ―Έμ§
wandb.log({
"examples": [wandb.Image(img, caption=f"Sample {i}")
for i, img in enumerate(image_batch[:10])]
})
# 2. νλ‘― λ‘κΉ
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9])
ax.set_title("Training Curve")
wandb.log({"plot": wandb.Image(fig)})
plt.close()
# λλ plotly μ¬μ©
import plotly.express as px
fig = px.scatter(x=[1, 2, 3], y=[1, 4, 9])
wandb.log({"plotly_chart": fig})
# 3. νμ€ν κ·Έλ¨
wandb.log({"predictions": wandb.Histogram(predictions)})
# 4. ν
μ΄λΈ
columns = ["id", "image", "prediction", "label"]
data = [
[i, wandb.Image(img), pred, label]
for i, (img, pred, label) in enumerate(zip(images, preds, labels))
]
table = wandb.Table(columns=columns, data=data)
wandb.log({"predictions_table": table})
# 5. Confusion Matrix
wandb.log({
"confusion_matrix": wandb.plot.confusion_matrix(
y_true=y_true,
preds=y_pred,
class_names=class_names
)
})
# 6. ROC Curve
wandb.log({
"roc_curve": wandb.plot.roc_curve(
y_true, y_scores, labels=class_names
)
})
# 7. PR Curve
wandb.log({
"pr_curve": wandb.plot.pr_curve(
y_true, y_scores, labels=class_names
)
})
wandb.finish()
3. Sweeps (Hyperparameter Tuning)¶
3.1 Sweep Configuration¶
"""
W&B Sweeps μ€μ
"""
import wandb
# Sweep μ€μ
sweep_config = {
"name": "hyperparam-sweep",
"method": "bayes", # random, grid, bayes
"metric": {
"name": "val_accuracy",
"goal": "maximize"
},
"parameters": {
"learning_rate": {
"distribution": "log_uniform_values",
"min": 1e-5,
"max": 1e-1
},
"batch_size": {
"values": [16, 32, 64, 128]
},
"epochs": {
"value": 50 # κ³ μ κ°
},
"optimizer": {
"values": ["adam", "sgd", "rmsprop"]
},
"hidden_dim": {
"distribution": "int_uniform",
"min": 32,
"max": 256
},
"dropout": {
"distribution": "uniform",
"min": 0.0,
"max": 0.5
}
},
"early_terminate": {
"type": "hyperband",
"min_iter": 5,
"eta": 3
}
}
# Sweep μμ±
sweep_id = wandb.sweep(sweep_config, project="my-project")
print(f"Sweep ID: {sweep_id}")
3.2 Running a Sweep Agent¶
"""
Sweep νμ΅ ν¨μ
"""
import wandb
import torch
def train_sweep():
    """Training function executed by the sweep agent.

    The sweep controller injects the sampled hyperparameters into
    ``wandb.config`` when ``wandb.init()`` is called with no arguments.

    NOTE(review): ``create_model``, ``train_dataset``, ``val_dataset``,
    ``train_one_epoch`` and ``evaluate`` are assumed to be defined
    elsewhere in the tutorial.
    """
    # FIX: DataLoader was used below without being imported in this snippet.
    from torch.utils.data import DataLoader

    # Initialize W&B (the sweep supplies the config)
    wandb.init()
    config = wandb.config

    # Build the model from the sampled hyperparameters
    model = create_model(
        hidden_dim=config.hidden_dim,
        dropout=config.dropout
    )

    # Optimizer selection
    if config.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    elif config.optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
    else:
        optimizer = torch.optim.RMSprop(model.parameters(), lr=config.learning_rate)

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size)

    # Training loop
    for epoch in range(config.epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer)
        val_accuracy = evaluate(model, val_loader)
        wandb.log({
            "train_loss": train_loss,
            "val_accuracy": val_accuracy,
            "epoch": epoch
        })

    wandb.finish()
# Run the sweep
wandb.agent(
    sweep_id,
    function=train_sweep,
    count=50  # maximum number of runs
)
3.3 Running Sweeps from the CLI¶
# Create a sweep from a YAML config file
wandb sweep sweep.yaml
# Run an agent (can run in parallel on several machines)
wandb agent username/project/sweep_id
# sweep.yaml — CLI equivalent of the Python sweep configuration
name: hyperparameter-sweep
method: bayes
metric:
  name: val_accuracy
  goal: maximize
parameters:
  learning_rate:
    distribution: log_uniform_values
    min: 0.00001
    max: 0.1
  batch_size:
    values: [16, 32, 64]
  hidden_dim:
    distribution: int_uniform
    min: 64
    max: 512
4. Artifacts¶
4.1 Dataset Versioning¶
"""
W&B Artifactsλ‘ λ°μ΄ν°μ
κ΄λ¦¬
"""
import wandb
# μν°ν©νΈ μμ± λ° μ
λ‘λ
wandb.init(project="dataset-versioning")
# λ°μ΄ν°μ
μν°ν©νΈ μμ±
dataset_artifact = wandb.Artifact(
name="mnist-dataset",
type="dataset",
description="MNIST dataset for classification",
metadata={
"size": 70000,
"classes": 10,
"source": "torchvision"
}
)
# νμΌ/λλ ν 리 μΆκ°
dataset_artifact.add_file("data/train.csv")
dataset_artifact.add_dir("data/images/")
# μ격 μ°Έμ‘° μΆκ° (λ€μ΄λ‘λ μμ΄ μ°Έμ‘°λ§)
dataset_artifact.add_reference("s3://bucket/large_data/")
# μ
λ‘λ
wandb.log_artifact(dataset_artifact)
wandb.finish()
4.2 Model Artifacts¶
"""
λͺ¨λΈ μν°ν©νΈ κ΄λ¦¬
"""
import wandb
import torch
wandb.init(project="model-artifacts")
# νμ΅ ν...
# λͺ¨λΈ μν°ν©νΈ μμ±
model_artifact = wandb.Artifact(
name="churn-model",
type="model",
description="Customer churn prediction model",
metadata={
"accuracy": 0.95,
"framework": "pytorch",
"architecture": "MLP"
}
)
# λͺ¨λΈ νμΌ μ μ₯ λ° μΆκ°
torch.save(model.state_dict(), "model.pth")
model_artifact.add_file("model.pth")
# μ€μ νμΌλ ν¨κ»
model_artifact.add_file("config.yaml")
# μ
λ‘λ
wandb.log_artifact(model_artifact)
# λͺ¨λΈμ νΉμ λ³μΉμΌλ‘ μ°κ²°
wandb.run.link_artifact(model_artifact, "model-registry/churn-model", aliases=["latest", "production"])
wandb.finish()
4.3 Using Artifacts¶
"""
μν°ν©νΈ λ€μ΄λ‘λ λ° μ¬μ©
"""
import wandb
wandb.init(project="using-artifacts")
# μν°ν©νΈ λ€μ΄λ‘λ
artifact = wandb.use_artifact("mnist-dataset:latest") # λλ :v0, :v1 λ±
artifact_dir = artifact.download()
print(f"Downloaded to: {artifact_dir}")
# μν°ν©νΈ νμΌ μ§μ μ κ·Ό
with artifact.file("train.csv") as f:
df = pd.read_csv(f)
# μμ‘΄μ± κΈ°λ‘ (μ΄ runμ΄ μ΄ artifactλ₯Ό μ¬μ©ν¨)
# use_artifact()κ° μλμΌλ‘ μ²λ¦¬
wandb.finish()
4.4 Artifact Lineage¶
"""
μν°ν©νΈ 리λμ§ μΆμ
"""
import wandb
# λ°μ΄ν° β νμ΅ β λͺ¨λΈ 리λμ§
wandb.init(project="lineage-demo")
# 1. μ
λ ₯ μν°ν©νΈ (λ°μ΄ν°μ
)
dataset = wandb.use_artifact("processed-data:latest")
# 2. νμ΅ μν
# ...
# 3. μΆλ ₯ μν°ν©νΈ (λͺ¨λΈ)
model_artifact = wandb.Artifact("trained-model", type="model")
model_artifact.add_file("model.pth")
wandb.log_artifact(model_artifact)
# W&B UIμμ μ 체 리λμ§ κ·Έλν νμΈ κ°λ₯
# λ°μ΄ν°μ
β (νμ΅ run) β λͺ¨λΈ
wandb.finish()
5. Comparison with MLflow¶
5.1 Feature Comparison¶
"""
MLflow vs W&B λΉκ΅
"""
comparison = {
"μ€ν μΆμ ": {
"MLflow": "μ€νμμ€, self-hosted",
"W&B": "SaaS κΈ°λ°, λ¬΄λ£ ν°μ΄ μ 곡"
},
"μκ°ν": {
"MLflow": "κΈ°λ³Έ μκ°ν",
"W&B": "νλΆν μκ°ν, μ€μκ° μ
λ°μ΄νΈ"
},
"νμ
": {
"MLflow": "μ νμ ",
"W&B": "ν κΈ°λ₯, 리ν¬νΈ 곡μ "
},
"νμ΄νΌνλΌλ―Έν° νλ": {
"MLflow": "μΈλΆ λꡬ νμ (Optuna λ±)",
"W&B": "Sweeps λ΄μ₯"
},
"λͺ¨λΈ λ μ§μ€νΈλ¦¬": {
"MLflow": "μμ ν κΈ°λ₯",
"W&B": "Model Registry (μ΅κ·Ό μΆκ°)"
},
"λ°°ν¬": {
"MLflow": "MLflow Serving",
"W&B": "μ§μ μ§μ μμ (λ€λ₯Έ λꡬ μ°λ)"
},
"λΉμ©": {
"MLflow": "λ¬΄λ£ (μΈνλΌ λΉμ©λ§)",
"W&B": "λ¬΄λ£ ν°μ΄ + μ λ£ νλ"
}
}
5.2 Using Them Together¶
"""
MLflowμ W&B λμ μ¬μ©
"""
import mlflow
import wandb
# λ νλ«νΌ λͺ¨λ μ΄κΈ°ν
wandb.init(project="dual-tracking")
mlflow.set_experiment("dual-tracking")
with mlflow.start_run():
# κ³΅ν΅ μ€μ
params = {"lr": 0.001, "epochs": 100}
# μμͺ½μ νλΌλ―Έν° λ‘κΉ
mlflow.log_params(params)
wandb.config.update(params)
# νμ΅ λ£¨ν
for epoch in range(params["epochs"]):
loss = train_one_epoch()
accuracy = evaluate()
# μμͺ½μ λ©νΈλ¦ λ‘κΉ
mlflow.log_metrics({"loss": loss, "accuracy": accuracy}, step=epoch)
wandb.log({"loss": loss, "accuracy": accuracy, "epoch": epoch})
# λͺ¨λΈ μ μ₯
mlflow.sklearn.log_model(model, "model")
wandb.save("model.pkl")
wandb.finish()
6. Advanced Features¶
6.1 Team Collaboration¶
"""
ν νλ‘μ νΈ μ€μ
"""
import wandb
# ν νλ‘μ νΈμ λ‘κΉ
wandb.init(
entity="team-name", # ν μ΄λ¦
project="shared-project", # νλ‘μ νΈ μ΄λ¦
group="experiment-group", # μ€ν κ·Έλ£Ή (κ΄λ ¨ μ€ν λ¬ΆκΈ°)
job_type="training" # μμ
μ ν
)
6.2 Creating Reports¶
"""
W&B Reports API
"""
import wandb
# 리ν¬νΈλ μ£Όλ‘ UIμμ μμ±νμ§λ§ APIλ‘λ κ°λ₯
api = wandb.Api()
# νλ‘μ νΈμ λͺ¨λ μ€ν μ‘°ν
runs = api.runs("username/project")
for run in runs:
print(f"Run: {run.name}")
print(f" Config: {run.config}")
print(f" Summary: {run.summary}")
print(f" History: {run.history().shape}")
6.3 Setting Up Alerts¶
"""
W&B Alerts
"""
import wandb
wandb.init(project="alerting-demo")
# νμ΅ μ€ μλ¦Ό νΈλ¦¬κ±°
for epoch in range(100):
accuracy = train_and_evaluate()
if accuracy > 0.95:
wandb.alert(
title="High Accuracy Achieved!",
text=f"Model achieved {accuracy:.2%} accuracy at epoch {epoch}",
level=wandb.AlertLevel.INFO
)
if accuracy < 0.5:
wandb.alert(
title="Training Issue",
text=f"Accuracy dropped to {accuracy:.2%}",
level=wandb.AlertLevel.WARN
)
wandb.log({"accuracy": accuracy, "epoch": epoch})
wandb.finish()
Practice Problems¶
Problem 1: Basic Experiment Tracking¶
Train a CNN on the MNIST dataset and track the experiment with W&B.
Problem 2: Running Sweeps¶
Run a Bayesian-optimization sweep over at least three hyperparameters.
Problem 3: Artifacts¶
Store the dataset and model as artifacts and inspect their lineage.
Summary¶
| Feature | W&B | MLflow |
|---|---|---|
| Experiment tracking | wandb.log() | mlflow.log_metrics() |
| Hyperparameter tuning | Sweeps | external tools |
| Data/model versioning | Artifacts | Model Registry |
| Visualization | rich dashboards | basic UI |
| Collaboration | teams, reports | limited |
| Hosting | SaaS / Self-hosted | Self-hosted |