1"""
2Weights & Biases Experiment Tracking Example
3============================================
4
5W&B๋ฅผ ์ฌ์ฉํ ์คํ ์ถ์ ์์ ์
๋๋ค.
6
7์คํ ๋ฐฉ๋ฒ:
8 # W&B ๋ก๊ทธ์ธ
9 wandb login
10
11 # ์คํฌ๋ฆฝํธ ์คํ
12 python experiment_tracking.py
13"""
14
15import wandb
16import numpy as np
17from sklearn.datasets import load_breast_cancer
18from sklearn.model_selection import train_test_split, cross_val_score
19from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
20from sklearn.linear_model import LogisticRegression
21from sklearn.metrics import (
22 accuracy_score,
23 precision_score,
24 recall_score,
25 f1_score,
26 roc_auc_score,
27 confusion_matrix,
28 classification_report
29)
30import matplotlib.pyplot as plt
31import seaborn as sns
32
# W&B project settings
PROJECT_NAME = "breast-cancer-classification"
ENTITY = None  # team name (None for a personal account)
36
37
def load_data():
    """Load the breast-cancer dataset and return a stratified 80/20 split.

    Returns
    -------
    tuple
        (X_train, X_test, y_train, y_test, feature_names, target_names)
    """
    dataset = load_breast_cancer()
    # Stratify on the labels so both splits keep the class balance;
    # fixed seed makes the split reproducible across runs.
    split = train_test_split(
        dataset.data,
        dataset.target,
        test_size=0.2,
        random_state=42,
        stratify=dataset.target,
    )
    X_train, X_test, y_train, y_test = split
    return X_train, X_test, y_train, y_test, dataset.feature_names, dataset.target_names
48
49
def calculate_metrics(y_true, y_pred, y_proba=None):
    """Compute standard binary-classification metrics.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted labels.
    y_proba : array-like, optional
        Positive-class probabilities; when supplied, ROC AUC is included.

    Returns
    -------
    dict
        accuracy / precision / recall / f1_score, plus roc_auc when
        ``y_proba`` is given.
    """
    scorers = [
        ("accuracy", accuracy_score),
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1_score", f1_score),
    ]
    results = {name: scorer(y_true, y_pred) for name, scorer in scorers}

    # ROC AUC needs scores/probabilities rather than hard labels.
    if y_proba is not None:
        results["roc_auc"] = roc_auc_score(y_true, y_proba)

    return results
63
64
def train_with_wandb(model_name, model, params, X_train, X_test, y_train, y_test, feature_names):
    """Train one model while logging metrics, plots, and artifacts to W&B.

    Parameters
    ----------
    model_name : str
        Display name used for the run name, tags, and artifact name.
    model : estimator
        An unfitted scikit-learn classifier.
    params : dict
        Hyperparameters logged as the run config; also reused to rebuild
        partial models for the learning-curve loop below.
    X_train, X_test, y_train, y_test : array-like
        Pre-split features and labels.
    feature_names : sequence of str
        Feature labels for the importance plot/table.

    Returns
    -------
    dict
        Test-set metrics from ``calculate_metrics``.
    """

    # Start a W&B run; ``params`` becomes the run config.
    run = wandb.init(
        project=PROJECT_NAME,
        entity=ENTITY,
        name=model_name,
        config=params,
        tags=["baseline", model_name.lower()],
        notes=f"Training {model_name} on breast cancer dataset"
    )

    # Log dataset shape alongside the hyperparameters.
    wandb.config.update({
        "train_size": len(X_train),
        "test_size": len(X_test),
        "n_features": X_train.shape[1]
    })

    # Fit the final model on the full training split.
    model.fit(X_train, y_train)

    # 5-fold cross-validation accuracy on the training split.
    cv_scores = cross_val_score(model, X_train, y_train, cv=5, scoring="accuracy")
    wandb.log({
        "cv_mean": cv_scores.mean(),
        "cv_std": cv_scores.std()
    })

    # Simulated learning curve for ensemble models: refit a fresh model
    # at increasing ensemble sizes and log train/val accuracy per step.
    # NOTE(review): the partial models are rebuilt from ``params`` only, so a
    # random_state set on ``model`` but absent from ``params`` is not carried
    # over — the curve may not match the final model exactly; confirm intent.
    if hasattr(model, "n_estimators"):
        for i in range(1, params.get("n_estimators", 100) + 1, 10):
            partial_model = type(model)(**{**params, "n_estimators": i})
            partial_model.fit(X_train, y_train)
            train_score = partial_model.score(X_train, y_train)
            val_score = partial_model.score(X_test, y_test)
            wandb.log({
                "train_accuracy": train_score,
                "val_accuracy": val_score,
                "n_estimators": i
            })

    # Final predictions on the held-out test set.
    y_pred = model.predict(X_test)
    y_proba = model.predict_proba(X_test)[:, 1] if hasattr(model, "predict_proba") else None

    # Scalar test-set metrics.
    metrics = calculate_metrics(y_test, y_pred, y_proba)
    wandb.log(metrics)

    # Confusion matrix logged as an image.
    cm = confusion_matrix(y_test, y_pred)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
    plt.title(f"Confusion Matrix - {model_name}")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    wandb.log({"confusion_matrix": wandb.Image(plt)})
    plt.close()

    # Feature importance (only models exposing feature_importances_).
    if hasattr(model, "feature_importances_"):
        importance = model.feature_importances_
        indices = np.argsort(importance)[::-1][:15]  # top 15 features

        plt.figure(figsize=(12, 6))
        plt.bar(range(len(indices)), importance[indices])
        plt.xticks(range(len(indices)), [feature_names[i] for i in indices], rotation=45, ha="right")
        plt.title(f"Feature Importance - {model_name}")
        plt.tight_layout()
        wandb.log({"feature_importance": wandb.Image(plt)})
        plt.close()

        # Also log the importances as a W&B table.
        importance_data = [
            [feature_names[i], importance[i]]
            for i in indices
        ]
        table = wandb.Table(columns=["feature", "importance"], data=importance_data)
        wandb.log({"feature_importance_table": table})

    # ROC curve (requires class probabilities; wandb.plot.roc_curve
    # expects per-class probability columns, hence the column_stack).
    if y_proba is not None:
        wandb.log({
            "roc_curve": wandb.plot.roc_curve(y_test, np.column_stack([1-y_proba, y_proba]))
        })

    # Per-class precision/recall/F1 as a nested dict.
    report = classification_report(y_test, y_pred, output_dict=True)
    wandb.log({
        "classification_report": report
    })

    # Register a model artifact; in production the serialized model file
    # would be attached before logging.
    artifact = wandb.Artifact(
        name=f"{model_name.lower()}-model",
        type="model",
        description=f"{model_name} trained on breast cancer dataset"
    )
    # artifact.add_file("model.pkl")
    wandb.log_artifact(artifact)

    # Console summary of the run.
    print(f"\n{model_name} ๊ฒฐ๊ณผ:")
    print(f" Accuracy: {metrics['accuracy']:.4f}")
    print(f" F1 Score: {metrics['f1_score']:.4f}")
    print(f" ROC AUC: {metrics.get('roc_auc', 'N/A')}")
    print(f" CV Mean: {cv_scores.mean():.4f} (+/- {cv_scores.std():.4f})")

    # Finish the run so the next model starts a fresh one.
    wandb.finish()

    return metrics
181
182
def hyperparameter_sweep():
    """Create a W&B hyperparameter sweep for a RandomForest classifier.

    Builds a Bayesian-optimization sweep configuration, registers it with
    ``wandb.sweep``, and prints the agent command. The local agent call is
    left commented out so the script does not block on sweep execution.
    """

    # Search space for the sweep.
    search_space = {
        "n_estimators": {
            "values": [50, 100, 150, 200]
        },
        "max_depth": {
            "values": [3, 5, 7, 10, None]
        },
        "min_samples_split": {
            "distribution": "int_uniform",
            "min": 2,
            "max": 20
        },
        "min_samples_leaf": {
            "distribution": "int_uniform",
            "min": 1,
            "max": 10
        },
    }

    # Sweep configuration; method may be random, grid, or bayes.
    sweep_config = {
        "name": "rf-hyperparameter-sweep",
        "method": "bayes",
        "metric": {"name": "val_accuracy", "goal": "maximize"},
        "parameters": search_space,
    }

    def train_sweep():
        """Single sweep trial: fit a RandomForest with the sampled config."""
        wandb.init()
        cfg = wandb.config

        X_tr, X_te, y_tr, y_te, _, _ = load_data()

        clf = RandomForestClassifier(
            n_estimators=cfg.n_estimators,
            max_depth=cfg.max_depth,
            min_samples_split=cfg.min_samples_split,
            min_samples_leaf=cfg.min_samples_leaf,
            random_state=42,
        )
        clf.fit(X_tr, y_tr)

        # Report the objective metric the sweep optimizes.
        wandb.log({"val_accuracy": clf.score(X_te, y_te)})
        wandb.finish()

    # Register the sweep and show how to launch an agent for it.
    sweep_id = wandb.sweep(sweep_config, project=PROJECT_NAME)
    print(f"\n์ค์ ID: {sweep_id}")
    print("์ค์์ ์คํํ๋ ค๋ฉด:")
    print(f" wandb agent {sweep_id}")

    # Run locally (optional):
    # wandb.agent(sweep_id, function=train_sweep, count=20)
243
244
def main():
    """Run the full example: load data, train three models with W&B
    tracking, and print a comparison summary."""
    print("="*60)
    print("Weights & Biases ์คํ ์ถ์ ์์ ")
    print("="*60)

    # Load and split the dataset.
    X_train, X_test, y_train, y_test, feature_names, target_names = load_data()
    print(f"\n๋ฐ์ดํฐ์:")
    print(f" ํ์ต ๋ฐ์ดํฐ: {len(X_train)} ์ํ")
    print(f" ํ์คํธ ๋ฐ์ดํฐ: {len(X_test)} ์ํ")
    print(f" ํผ์ฒ ์: {len(feature_names)}")

    # Models to compare: (display name, estimator, logged hyperparameters).
    models = [
        (
            "RandomForest",
            RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42),
            {"n_estimators": 100, "max_depth": 10}
        ),
        (
            "GradientBoosting",
            GradientBoostingClassifier(n_estimators=100, max_depth=5, random_state=42),
            {"n_estimators": 100, "max_depth": 5, "learning_rate": 0.1}
        ),
        (
            "LogisticRegression",
            LogisticRegression(max_iter=1000, random_state=42),
            {"max_iter": 1000, "solver": "lbfgs"}
        )
    ]

    # Train each model in its own W&B run.
    results = {}
    for model_name, model, params in models:
        print(f"\n{'='*40}")
        print(f"{model_name} ํ์ต ์ค...")
        metrics = train_with_wandb(
            model_name, model, params,
            X_train, X_test, y_train, y_test,
            feature_names
        )
        results[model_name] = metrics

    # Summary table across models.
    print("\n" + "="*60)
    print("๊ฒฐ๊ณผ ์์ฝ")
    print("="*60)
    print(f"\n{'Model':<20} {'Accuracy':<12} {'F1 Score':<12} {'ROC AUC':<12}")
    print("-"*60)
    for name, metrics in results.items():
        print(f"{name:<20} {metrics['accuracy']:<12.4f} {metrics['f1_score']:<12.4f} {metrics.get('roc_auc', 0):<12.4f}")

    print(f"\nW&B ๋์๋ณด๋์์ ์์ธํ ๊ฒฐ๊ณผ๋ฅผ ํ์ธํ์ธ์:")
    print(f" https://wandb.ai/{ENTITY or 'your-username'}/{PROJECT_NAME}")
300
301
if __name__ == "__main__":
    # Script entry point: run the model-comparison example.
    main()