# experiment_tracking.py
"""
Weights & Biases Experiment Tracking Example
============================================

Example of experiment tracking with W&B.

How to run:
    # Log in to W&B
    wandb login

    # Run the script
    python experiment_tracking.py
"""
 14
 15import wandb
 16import numpy as np
 17from sklearn.datasets import load_breast_cancer
 18from sklearn.model_selection import train_test_split, cross_val_score
 19from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
 20from sklearn.linear_model import LogisticRegression
 21from sklearn.metrics import (
 22    accuracy_score,
 23    precision_score,
 24    recall_score,
 25    f1_score,
 26    roc_auc_score,
 27    confusion_matrix,
 28    classification_report
 29)
 30import matplotlib.pyplot as plt
 31import seaborn as sns
 32
# W&B project settings
PROJECT_NAME = "breast-cancer-classification"
ENTITY = None  # Team name (None for a personal account)
 36
 37
 38def load_data():
 39    """๋ฐ์ดํ„ฐ ๋กœ๋“œ ๋ฐ ๋ถ„ํ• """
 40    data = load_breast_cancer()
 41    X_train, X_test, y_train, y_test = train_test_split(
 42        data.data, data.target,
 43        test_size=0.2,
 44        random_state=42,
 45        stratify=data.target
 46    )
 47    return X_train, X_test, y_train, y_test, data.feature_names, data.target_names
 48
 49
 50def calculate_metrics(y_true, y_pred, y_proba=None):
 51    """๋ฉ”ํŠธ๋ฆญ ๊ณ„์‚ฐ"""
 52    metrics = {
 53        "accuracy": accuracy_score(y_true, y_pred),
 54        "precision": precision_score(y_true, y_pred),
 55        "recall": recall_score(y_true, y_pred),
 56        "f1_score": f1_score(y_true, y_pred)
 57    }
 58
 59    if y_proba is not None:
 60        metrics["roc_auc"] = roc_auc_score(y_true, y_proba)
 61
 62    return metrics
 63
 64
 65def train_with_wandb(model_name, model, params, X_train, X_test, y_train, y_test, feature_names):
 66    """W&B๋กœ ์‹คํ—˜ ์ถ”์ ํ•˜๋ฉฐ ๋ชจ๋ธ ํ•™์Šต"""
 67
 68    # W&B ์ดˆ๊ธฐํ™”
 69    run = wandb.init(
 70        project=PROJECT_NAME,
 71        entity=ENTITY,
 72        name=model_name,
 73        config=params,
 74        tags=["baseline", model_name.lower()],
 75        notes=f"Training {model_name} on breast cancer dataset"
 76    )
 77
 78    # ์ถ”๊ฐ€ ์„ค์ • ๋กœ๊น…
 79    wandb.config.update({
 80        "train_size": len(X_train),
 81        "test_size": len(X_test),
 82        "n_features": X_train.shape[1]
 83    })
 84
 85    # ๋ชจ๋ธ ํ•™์Šต
 86    model.fit(X_train, y_train)
 87
 88    # ๊ต์ฐจ ๊ฒ€์ฆ
 89    cv_scores = cross_val_score(model, X_train, y_train, cv=5, scoring="accuracy")
 90    wandb.log({
 91        "cv_mean": cv_scores.mean(),
 92        "cv_std": cv_scores.std()
 93    })
 94
 95    # ํ•™์Šต ๊ณก์„  ์‹œ๋ฎฌ๋ ˆ์ด์…˜ (์ผ๋ถ€ ๋ชจ๋ธ์—์„œ)
 96    if hasattr(model, "n_estimators"):
 97        # ๋‹จ๊ณ„๋ณ„ ์„ฑ๋Šฅ ๊ธฐ๋ก
 98        for i in range(1, params.get("n_estimators", 100) + 1, 10):
 99            partial_model = type(model)(**{**params, "n_estimators": i})
100            partial_model.fit(X_train, y_train)
101            train_score = partial_model.score(X_train, y_train)
102            val_score = partial_model.score(X_test, y_test)
103            wandb.log({
104                "train_accuracy": train_score,
105                "val_accuracy": val_score,
106                "n_estimators": i
107            })
108
109    # ์ตœ์ข… ์˜ˆ์ธก
110    y_pred = model.predict(X_test)
111    y_proba = model.predict_proba(X_test)[:, 1] if hasattr(model, "predict_proba") else None
112
113    # ๋ฉ”ํŠธ๋ฆญ ๋กœ๊น…
114    metrics = calculate_metrics(y_test, y_pred, y_proba)
115    wandb.log(metrics)
116
117    # Confusion Matrix
118    cm = confusion_matrix(y_test, y_pred)
119    plt.figure(figsize=(8, 6))
120    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
121    plt.title(f"Confusion Matrix - {model_name}")
122    plt.xlabel("Predicted")
123    plt.ylabel("Actual")
124    wandb.log({"confusion_matrix": wandb.Image(plt)})
125    plt.close()
126
127    # Feature Importance (ํ•ด๋‹นํ•˜๋Š” ๊ฒฝ์šฐ)
128    if hasattr(model, "feature_importances_"):
129        importance = model.feature_importances_
130        indices = np.argsort(importance)[::-1][:15]  # ์ƒ์œ„ 15๊ฐœ
131
132        plt.figure(figsize=(12, 6))
133        plt.bar(range(len(indices)), importance[indices])
134        plt.xticks(range(len(indices)), [feature_names[i] for i in indices], rotation=45, ha="right")
135        plt.title(f"Feature Importance - {model_name}")
136        plt.tight_layout()
137        wandb.log({"feature_importance": wandb.Image(plt)})
138        plt.close()
139
140        # ํ…Œ์ด๋ธ”๋กœ๋„ ๋กœ๊น…
141        importance_data = [
142            [feature_names[i], importance[i]]
143            for i in indices
144        ]
145        table = wandb.Table(columns=["feature", "importance"], data=importance_data)
146        wandb.log({"feature_importance_table": table})
147
148    # ROC Curve (ํ™•๋ฅ  ์˜ˆ์ธก ๊ฐ€๋Šฅํ•œ ๊ฒฝ์šฐ)
149    if y_proba is not None:
150        wandb.log({
151            "roc_curve": wandb.plot.roc_curve(y_test, np.column_stack([1-y_proba, y_proba]))
152        })
153
154    # Classification Report
155    report = classification_report(y_test, y_pred, output_dict=True)
156    wandb.log({
157        "classification_report": report
158    })
159
160    # ๋ชจ๋ธ ์•„ํ‹ฐํŒฉํŠธ ์ €์žฅ
161    artifact = wandb.Artifact(
162        name=f"{model_name.lower()}-model",
163        type="model",
164        description=f"{model_name} trained on breast cancer dataset"
165    )
166    # ์‹ค์ œ ํ”„๋กœ๋•์…˜์—์„œ๋Š” ๋ชจ๋ธ ํŒŒ์ผ ์ €์žฅ
167    # artifact.add_file("model.pkl")
168    wandb.log_artifact(artifact)
169
170    # ๊ฒฐ๊ณผ ์ถœ๋ ฅ
171    print(f"\n{model_name} ๊ฒฐ๊ณผ:")
172    print(f"  Accuracy: {metrics['accuracy']:.4f}")
173    print(f"  F1 Score: {metrics['f1_score']:.4f}")
174    print(f"  ROC AUC: {metrics.get('roc_auc', 'N/A')}")
175    print(f"  CV Mean: {cv_scores.mean():.4f} (+/- {cv_scores.std():.4f})")
176
177    # ์‹คํ–‰ ์ข…๋ฃŒ
178    wandb.finish()
179
180    return metrics
181
182
def hyperparameter_sweep():
    """Create a W&B Bayesian sweep over RandomForest hyperparameters.

    Registers the sweep with W&B, prints the sweep id and the agent
    command; actually running the agent locally is left commented out.
    """

    # Search space: discrete choices plus two integer-uniform distributions.
    param_space = {
        "n_estimators": {"values": [50, 100, 150, 200]},
        "max_depth": {"values": [3, 5, 7, 10, None]},
        "min_samples_split": {"distribution": "int_uniform", "min": 2, "max": 20},
        "min_samples_leaf": {"distribution": "int_uniform", "min": 1, "max": 10},
    }
    sweep_config = {
        "name": "rf-hyperparameter-sweep",
        "method": "bayes",  # alternatives: random, grid
        "metric": {"name": "val_accuracy", "goal": "maximize"},
        "parameters": param_space,
    }

    def train_sweep():
        """One sweep trial: train a RandomForest with the sampled config."""
        wandb.init()
        cfg = wandb.config

        X_train, X_test, y_train, y_test, _, _ = load_data()

        clf = RandomForestClassifier(
            n_estimators=cfg.n_estimators,
            max_depth=cfg.max_depth,
            min_samples_split=cfg.min_samples_split,
            min_samples_leaf=cfg.min_samples_leaf,
            random_state=42,
        )
        clf.fit(X_train, y_train)

        wandb.log({"val_accuracy": clf.score(X_test, y_test)})
        wandb.finish()

    # Register the sweep and tell the user how to launch agents for it.
    sweep_id = wandb.sweep(sweep_config, project=PROJECT_NAME)
    print(f"\n์Šค์œ• ID: {sweep_id}")
    print("์Šค์œ•์„ ์‹คํ–‰ํ•˜๋ ค๋ฉด:")
    print(f"  wandb agent {sweep_id}")

    # Run locally (optional):
    # wandb.agent(sweep_id, function=train_sweep, count=20)
243
244
def main():
    """Train several classifiers on the breast cancer dataset, tracking each in W&B."""
    banner = "=" * 60
    print(banner)
    print("Weights & Biases ์‹คํ—˜ ์ถ”์  ์˜ˆ์ œ")
    print(banner)

    # Load and describe the data split.
    X_train, X_test, y_train, y_test, feature_names, target_names = load_data()
    print(f"\n๋ฐ์ดํ„ฐ์…‹:")
    print(f"  ํ•™์Šต ๋ฐ์ดํ„ฐ: {len(X_train)} ์ƒ˜ํ”Œ")
    print(f"  ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ: {len(X_test)} ์ƒ˜ํ”Œ")
    print(f"  ํ”ผ์ฒ˜ ์ˆ˜: {len(feature_names)}")

    # (run name, estimator, hyperparameters to log in the run config)
    model_specs = [
        ("RandomForest",
         RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42),
         {"n_estimators": 100, "max_depth": 10}),
        ("GradientBoosting",
         GradientBoostingClassifier(n_estimators=100, max_depth=5, random_state=42),
         {"n_estimators": 100, "max_depth": 5, "learning_rate": 0.1}),
        ("LogisticRegression",
         LogisticRegression(max_iter=1000, random_state=42),
         {"max_iter": 1000, "solver": "lbfgs"}),
    ]

    # Train each model as its own tracked W&B run.
    results = {}
    for name, estimator, hyperparams in model_specs:
        print(f"\n{'='*40}")
        print(f"{name} ํ•™์Šต ์ค‘...")
        results[name] = train_with_wandb(
            name, estimator, hyperparams,
            X_train, X_test, y_train, y_test,
            feature_names,
        )

    # Summary table across all models.
    print("\n" + banner)
    print("๊ฒฐ๊ณผ ์š”์•ฝ")
    print(banner)
    print(f"\n{'Model':<20} {'Accuracy':<12} {'F1 Score':<12} {'ROC AUC':<12}")
    print("-" * 60)
    for name, metrics in results.items():
        print(f"{name:<20} {metrics['accuracy']:<12.4f} {metrics['f1_score']:<12.4f} {metrics.get('roc_auc', 0):<12.4f}")

    print(f"\nW&B ๋Œ€์‹œ๋ณด๋“œ์—์„œ ์ž์„ธํ•œ ๊ฒฐ๊ณผ๋ฅผ ํ™•์ธํ•˜์„ธ์š”:")
    print(f"  https://wandb.ai/{ENTITY or 'your-username'}/{PROJECT_NAME}")
300
301
# Script entry point.
if __name__ == "__main__":
    main()