model_registry.py

Download
python 228 lines 6.5 KB
  1"""
  2MLflow Model Registry Example
  3=============================
  4
Example of model version management using the MLflow Model Registry.
  6
How to run:
    # First run tracking_example.py to train and save a model, then:
  9    python model_registry.py
 10"""
 11
 12import mlflow
 13from mlflow.tracking import MlflowClient
 14from sklearn.datasets import load_iris
 15from sklearn.model_selection import train_test_split
 16from sklearn.ensemble import RandomForestClassifier
 17from sklearn.metrics import accuracy_score
 18import os
 19
 20# MLflow ์„ค์ •
 21TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")
 22MODEL_NAME = "iris-classifier"
 23
 24
 25def setup():
 26    """MLflow ์„ค์ •"""
 27    mlflow.set_tracking_uri(TRACKING_URI)
 28    mlflow.set_experiment("model-registry-demo")
 29    return MlflowClient()
 30
 31
 32def train_and_register_model(client, version_tag: str):
 33    """๋ชจ๋ธ ํ•™์Šต ๋ฐ ๋ ˆ์ง€์ŠคํŠธ๋ฆฌ ๋“ฑ๋ก"""
 34    # ๋ฐ์ดํ„ฐ ์ค€๋น„
 35    iris = load_iris()
 36    X_train, X_test, y_train, y_test = train_test_split(
 37        iris.data, iris.target, test_size=0.2, random_state=42
 38    )
 39
 40    with mlflow.start_run(run_name=f"training-{version_tag}") as run:
 41        # ๋ชจ๋ธ ํ•™์Šต
 42        model = RandomForestClassifier(n_estimators=100, random_state=42)
 43        model.fit(X_train, y_train)
 44
 45        # ํ‰๊ฐ€
 46        accuracy = accuracy_score(y_test, model.predict(X_test))
 47        mlflow.log_metric("accuracy", accuracy)
 48
 49        # ๋ชจ๋ธ ์ €์žฅ ๋ฐ ๋“ฑ๋ก
 50        mlflow.sklearn.log_model(
 51            model,
 52            "model",
 53            registered_model_name=MODEL_NAME
 54        )
 55
 56        print(f"\n๋ชจ๋ธ ๋“ฑ๋ก ์™„๋ฃŒ: {MODEL_NAME}")
 57        print(f"  Run ID: {run.info.run_id}")
 58        print(f"  Accuracy: {accuracy:.4f}")
 59
 60        return run.info.run_id
 61
 62
 63def get_model_versions(client):
 64    """๋“ฑ๋ก๋œ ๋ชจ๋ธ ๋ฒ„์ „ ์กฐํšŒ"""
 65    print(f"\n{'='*50}")
 66    print(f"๋ชจ๋ธ '{MODEL_NAME}' ๋ฒ„์ „ ๋ชฉ๋ก:")
 67    print("="*50)
 68
 69    try:
 70        versions = client.search_model_versions(f"name='{MODEL_NAME}'")
 71        for v in versions:
 72            print(f"\n๋ฒ„์ „ {v.version}:")
 73            print(f"  ์ƒํƒœ: {v.current_stage}")
 74            print(f"  Run ID: {v.run_id}")
 75            print(f"  ์ƒ์„ฑ์ผ: {v.creation_timestamp}")
 76            if v.description:
 77                print(f"  ์„ค๋ช…: {v.description}")
 78        return versions
 79    except Exception as e:
 80        print(f"๋ชจ๋ธ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {e}")
 81        return []
 82
 83
 84def transition_to_staging(client, version: str):
 85    """๋ชจ๋ธ์„ Staging์œผ๋กœ ์ „ํ™˜"""
 86    client.transition_model_version_stage(
 87        name=MODEL_NAME,
 88        version=version,
 89        stage="Staging",
 90        archive_existing_versions=False
 91    )
 92    print(f"\n๋ชจ๋ธ v{version}์„ Staging์œผ๋กœ ์ „ํ™˜ํ–ˆ์Šต๋‹ˆ๋‹ค.")
 93
 94
 95def transition_to_production(client, version: str):
 96    """๋ชจ๋ธ์„ Production์œผ๋กœ ์ „ํ™˜"""
 97    client.transition_model_version_stage(
 98        name=MODEL_NAME,
 99        version=version,
100        stage="Production",
101        archive_existing_versions=True
102    )
103    print(f"\n๋ชจ๋ธ v{version}์„ Production์œผ๋กœ ์ „ํ™˜ํ–ˆ์Šต๋‹ˆ๋‹ค.")
104
105
def update_model_description(client, version: str, description: str):
    """Set the description of one version of ``MODEL_NAME``.

    Args:
        client: Registry client providing ``update_model_version``.
        version: Model version whose description is replaced.
        description: New description text.
    """
    fields = {
        "name": MODEL_NAME,
        "version": version,
        "description": description,
    }
    client.update_model_version(**fields)
    print(f"\n๋ชจ๋ธ v{version} ์„ค๋ช…์„ ์—…๋ฐ์ดํŠธํ–ˆ์Šต๋‹ˆ๋‹ค.")
114
115
def add_model_tag(client, version: str, key: str, value: str):
    """Attach a ``key=value`` tag to one version of ``MODEL_NAME``.

    Args:
        client: Registry client providing ``set_model_version_tag``.
        version: Model version to tag.
        key: Tag name.
        value: Tag value.
    """
    tag_request = {
        "name": MODEL_NAME,
        "version": version,
        "key": key,
        "value": value,
    }
    client.set_model_version_tag(**tag_request)
    print(f"\n๋ชจ๋ธ v{version}์— ํƒœ๊ทธ ์ถ”๊ฐ€: {key}={value}")
125
126
def load_model_by_stage(stage: str):
    """Load the ``MODEL_NAME`` model registered under the given stage.

    Args:
        stage: Stage name used in the ``models:/<name>/<stage>`` URI.

    Returns:
        The loaded scikit-learn model, or ``None`` if loading raised an
        exception (the error is printed rather than propagated).
    """
    try:
        model_uri = f"models:/{MODEL_NAME}/{stage}"
        loaded = mlflow.sklearn.load_model(model_uri)
        print(f"\n{stage} ๋ชจ๋ธ ๋กœ๋“œ ์„ฑ๊ณต!")
    except Exception as e:
        print(f"\n{stage} ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: {e}")
        return None
    return loaded
136
137
def _latest_version(versions):
    """Return the numerically highest ``version`` among *versions*."""
    # Versions may arrive as strings; comparing them as text would rank
    # "9" above "10", so compare their integer values instead.
    return max(versions, key=lambda mv: int(mv.version)).version


def demo_workflow(client):
    """Run the full registry walkthrough against ``MODEL_NAME``.

    Steps: register a first model, describe and tag its version, move it to
    Staging, register a second model, move that one to Staging and then to
    Production, list the final state, and finally load the Production model
    and predict on three Iris samples.

    The walkthrough stops early (returns ``None``) whenever a version listing
    comes back empty, because there is nothing to transition in that case.

    Args:
        client: Registry client passed through to the helper functions.
    """
    print("\n" + "="*60)
    print("MLflow Model Registry ์›Œํฌํ”Œ๋กœ์šฐ ๋ฐ๋ชจ")
    print("="*60)

    # 1. Register the first model.
    print("\n[1] ์ฒซ ๋ฒˆ์งธ ๋ชจ๋ธ ํ•™์Šต ๋ฐ ๋“ฑ๋ก...")
    train_and_register_model(client, "v1")

    # 2. List versions; bail out if the registry reports none.
    versions = get_model_versions(client)
    if not versions:
        return

    latest_version = _latest_version(versions)

    # 3. Add a description.
    print("\n[2] ๋ชจ๋ธ ์„ค๋ช… ์ถ”๊ฐ€...")
    update_model_description(
        client, latest_version,
        "Initial model trained on Iris dataset with Random Forest"
    )

    # 4. Add tags.
    print("\n[3] ๋ชจ๋ธ ํƒœ๊ทธ ์ถ”๊ฐ€...")
    add_model_tag(client, latest_version, "validated", "true")
    add_model_tag(client, latest_version, "dataset", "iris")

    # 5. Move to Staging.
    print("\n[4] Staging์œผ๋กœ ์ „ํ™˜...")
    transition_to_staging(client, latest_version)

    # 6. Register the second model.
    print("\n[5] ๋‘ ๋ฒˆ์งธ ๋ชจ๋ธ ํ•™์Šต ๋ฐ ๋“ฑ๋ก (๊ฐœ์„  ๋ฒ„์ „)...")
    train_and_register_model(client, "v2")

    # 7. List versions again. get_model_versions returns [] on failure, so
    #    guard here as well instead of letting max() fail on an empty list.
    versions = get_model_versions(client)
    if not versions:
        return
    new_latest = _latest_version(versions)

    # 8. Move the new version to Staging.
    print("\n[6] ์ƒˆ ๋ฒ„์ „์„ Staging์œผ๋กœ...")
    transition_to_staging(client, new_latest)

    # 9. Promote to Production.
    print("\n[7] Production์œผ๋กœ ์Šน๊ฒฉ...")
    transition_to_production(client, new_latest)

    # 10. Show the final state.
    print("\n[8] ์ตœ์ข… ๋ชจ๋ธ ์ƒํƒœ:")
    get_model_versions(client)

    # 11. Load the Production model and run a quick prediction.
    print("\n[9] Production ๋ชจ๋ธ ๋กœ๋“œ ํ…Œ์ŠคํŠธ...")
    model = load_model_by_stage("Production")
    # load_model_by_stage signals failure with None; test for that explicitly
    # rather than relying on the truthiness of the estimator object.
    if model is not None:
        iris = load_iris()
        sample = iris.data[:3]
        predictions = model.predict(sample)
        print(f"  ์ƒ˜ํ”Œ ์˜ˆ์ธก ๊ฒฐ๊ณผ: {predictions}")
        print(f"  ์‹ค์ œ ๋ ˆ์ด๋ธ”: {iris.target[:3]}")
201
202
def main():
    """Show the menu, read the user's choice, and run the chosen action.

    Choice ``1`` trains and registers a model, ``2`` lists the registered
    versions, and ``3`` runs the full workflow demo. Any other input prints
    an error message.
    """
    client = setup()

    menu_lines = (
        "\nMLflow Model Registry ์˜ˆ์ œ",
        "="*50,
        "\n์˜ต์…˜:",
        "1. ์ƒˆ ๋ชจ๋ธ ํ•™์Šต ๋ฐ ๋“ฑ๋ก",
        "2. ๋“ฑ๋ก๋œ ๋ชจ๋ธ ์กฐํšŒ",
        "3. ์ „์ฒด ์›Œํฌํ”Œ๋กœ์šฐ ๋ฐ๋ชจ",
    )
    for line in menu_lines:
        print(line)

    selection = input("\n์„ ํƒ (1/2/3): ").strip()

    # Map each menu key to the action it triggers.
    actions = {
        "1": lambda: train_and_register_model(client, "manual"),
        "2": lambda: get_model_versions(client),
        "3": lambda: demo_workflow(client),
    }
    handler = actions.get(selection)
    if handler is None:
        print("์ž˜๋ชป๋œ ์„ ํƒ์ž…๋‹ˆ๋‹ค.")
        return
    handler()
224
225
# Run the interactive menu only when executed as a script, not on import.
if __name__ == "__main__":
    main()