tracking_example.py

Download
python 190 lines 5.7 KB
  1"""
MLflow Tracking Example
=======================

An example of experiment tracking with MLflow.

How to run:
    # Start the MLflow server
    mlflow server --backend-store-uri sqlite:///mlflow.db --port 5000

    # Run the script
    python tracking_example.py
 13"""
 14
 15import mlflow
 16import mlflow.sklearn
 17from sklearn.datasets import load_iris
 18from sklearn.model_selection import train_test_split, cross_val_score
 19from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
 20from sklearn.linear_model import LogisticRegression
 21from sklearn.metrics import (
 22    accuracy_score,
 23    precision_score,
 24    recall_score,
 25    f1_score,
 26    confusion_matrix
 27)
 28import matplotlib.pyplot as plt
 29import seaborn as sns
 30import numpy as np
 31import os
 32
 33# MLflow ์„ค์ •
 34TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")
 35EXPERIMENT_NAME = "iris-classification-demo"
 36
 37
 38def setup_mlflow():
 39    """MLflow ์ดˆ๊ธฐํ™”"""
 40    mlflow.set_tracking_uri(TRACKING_URI)
 41    mlflow.set_experiment(EXPERIMENT_NAME)
 42    print(f"MLflow Tracking URI: {TRACKING_URI}")
 43    print(f"Experiment: {EXPERIMENT_NAME}")
 44
 45
 46def load_data():
 47    """๋ฐ์ดํ„ฐ ๋กœ๋“œ ๋ฐ ๋ถ„ํ• """
 48    iris = load_iris()
 49    X_train, X_test, y_train, y_test = train_test_split(
 50        iris.data, iris.target,
 51        test_size=0.2,
 52        random_state=42,
 53        stratify=iris.target
 54    )
 55    return X_train, X_test, y_train, y_test, iris.target_names
 56
 57
 58def calculate_metrics(y_true, y_pred):
 59    """๋ฉ”ํŠธ๋ฆญ ๊ณ„์‚ฐ"""
 60    return {
 61        "accuracy": accuracy_score(y_true, y_pred),
 62        "precision_macro": precision_score(y_true, y_pred, average="macro"),
 63        "recall_macro": recall_score(y_true, y_pred, average="macro"),
 64        "f1_macro": f1_score(y_true, y_pred, average="macro")
 65    }
 66
 67
 68def plot_confusion_matrix(y_true, y_pred, class_names):
 69    """Confusion Matrix ์‹œ๊ฐํ™”"""
 70    cm = confusion_matrix(y_true, y_pred)
 71    fig, ax = plt.subplots(figsize=(8, 6))
 72    sns.heatmap(
 73        cm, annot=True, fmt="d", cmap="Blues",
 74        xticklabels=class_names,
 75        yticklabels=class_names,
 76        ax=ax
 77    )
 78    ax.set_xlabel("Predicted")
 79    ax.set_ylabel("Actual")
 80    ax.set_title("Confusion Matrix")
 81    return fig
 82
 83
 84def train_and_log(model, model_name, params, X_train, X_test, y_train, y_test, class_names):
 85    """๋ชจ๋ธ ํ•™์Šต ๋ฐ MLflow ๋กœ๊น…"""
 86    with mlflow.start_run(run_name=model_name):
 87        # ํŒŒ๋ผ๋ฏธํ„ฐ ๋กœ๊น…
 88        mlflow.log_params(params)
 89        mlflow.log_param("model_type", model_name)
 90        mlflow.log_param("train_size", len(X_train))
 91        mlflow.log_param("test_size", len(X_test))
 92
 93        # ๋ชจ๋ธ ํ•™์Šต
 94        model.fit(X_train, y_train)
 95
 96        # ๊ต์ฐจ ๊ฒ€์ฆ
 97        cv_scores = cross_val_score(model, X_train, y_train, cv=5)
 98        mlflow.log_metric("cv_mean", cv_scores.mean())
 99        mlflow.log_metric("cv_std", cv_scores.std())
100
101        # ์˜ˆ์ธก ๋ฐ ํ‰๊ฐ€
102        y_pred = model.predict(X_test)
103        metrics = calculate_metrics(y_test, y_pred)
104
105        # ๋ฉ”ํŠธ๋ฆญ ๋กœ๊น…
106        mlflow.log_metrics(metrics)
107
108        # Confusion Matrix ๋กœ๊น…
109        fig = plot_confusion_matrix(y_test, y_pred, class_names)
110        mlflow.log_figure(fig, "confusion_matrix.png")
111        plt.close(fig)
112
113        # Feature Importance (ํ•ด๋‹นํ•˜๋Š” ๊ฒฝ์šฐ)
114        if hasattr(model, "feature_importances_"):
115            fig, ax = plt.subplots(figsize=(10, 6))
116            importance = model.feature_importances_
117            feature_names = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
118            indices = np.argsort(importance)[::-1]
119            ax.bar(range(len(importance)), importance[indices])
120            ax.set_xticks(range(len(importance)))
121            ax.set_xticklabels([feature_names[i] for i in indices], rotation=45)
122            ax.set_title("Feature Importance")
123            mlflow.log_figure(fig, "feature_importance.png")
124            plt.close(fig)
125
126        # ๋ชจ๋ธ ์ €์žฅ
127        signature = mlflow.models.infer_signature(X_train, model.predict(X_train))
128        mlflow.sklearn.log_model(model, "model", signature=signature)
129
130        # ํƒœ๊ทธ ์ถ”๊ฐ€
131        mlflow.set_tag("validated", "true")
132        mlflow.set_tag("dataset", "iris")
133
134        print(f"\n{model_name}:")
135        print(f"  Accuracy: {metrics['accuracy']:.4f}")
136        print(f"  F1 Score: {metrics['f1_macro']:.4f}")
137        print(f"  CV Mean: {cv_scores.mean():.4f} (+/- {cv_scores.std():.4f})")
138
139        return mlflow.active_run().info.run_id
140
141
142def main():
143    """๋ฉ”์ธ ์‹คํ–‰ ํ•จ์ˆ˜"""
144    # MLflow ์„ค์ •
145    setup_mlflow()
146
147    # ๋ฐ์ดํ„ฐ ๋กœ๋“œ
148    X_train, X_test, y_train, y_test, class_names = load_data()
149
150    # ๋ชจ๋ธ ์ •์˜
151    models = [
152        (
153            "RandomForest",
154            RandomForestClassifier(n_estimators=100, max_depth=5, random_state=42),
155            {"n_estimators": 100, "max_depth": 5}
156        ),
157        (
158            "GradientBoosting",
159            GradientBoostingClassifier(n_estimators=100, max_depth=3, random_state=42),
160            {"n_estimators": 100, "max_depth": 3}
161        ),
162        (
163            "LogisticRegression",
164            LogisticRegression(max_iter=200, random_state=42),
165            {"max_iter": 200}
166        )
167    ]
168
169    # ๋ชจ๋ธ ํ•™์Šต ๋ฐ ๋กœ๊น…
170    run_ids = []
171    for model_name, model, params in models:
172        run_id = train_and_log(
173            model, model_name, params,
174            X_train, X_test, y_train, y_test,
175            class_names
176        )
177        run_ids.append((model_name, run_id))
178
179    # ๊ฒฐ๊ณผ ์ถœ๋ ฅ
180    print("\n" + "=" * 50)
181    print("์‹คํ—˜ ์™„๋ฃŒ!")
182    print(f"MLflow UI์—์„œ ๊ฒฐ๊ณผ ํ™•์ธ: {TRACKING_URI}")
183    print("\n๋“ฑ๋ก๋œ ์‹คํ–‰:")
184    for name, run_id in run_ids:
185        print(f"  - {name}: {run_id}")
186
187
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()