1"""
MLflow Model Registry Example
=============================

Example of model version management using the MLflow Model Registry.

How to run:
    # First run tracking_example.py to train and save a model, then:
    python model_registry.py
10"""
11
12import mlflow
13from mlflow.tracking import MlflowClient
14from sklearn.datasets import load_iris
15from sklearn.model_selection import train_test_split
16from sklearn.ensemble import RandomForestClassifier
17from sklearn.metrics import accuracy_score
18import os
19
# MLflow settings.
# The tracking server URI is read from the MLFLOW_TRACKING_URI environment
# variable; when it is unset, a server on localhost:5000 is used.
TRACKING_URI = os.getenv("MLFLOW_TRACKING_URI", "http://localhost:5000")
# Registered-model name used throughout this script.
MODEL_NAME = "iris-classifier"
23
24
def setup():
    """Configure MLflow for this demo and return a registry client.

    Sets the tracking URI to ``TRACKING_URI``, selects the
    ``model-registry-demo`` experiment, and then builds the client.

    Returns:
        A freshly constructed ``MlflowClient``.
    """
    mlflow.set_tracking_uri(TRACKING_URI)
    mlflow.set_experiment("model-registry-demo")
    registry_client = MlflowClient()
    return registry_client
30
31
def train_and_register_model(client, version_tag: str):
    """Train a random-forest classifier on Iris and register it.

    The model is trained inside a new MLflow run, its test accuracy is
    logged as the ``accuracy`` metric, and the fitted model is logged with
    ``registered_model_name=MODEL_NAME`` so that it is registered as part
    of the same call.

    Args:
        client: Not used by this function; kept so the signature matches
            the other helpers in this script.
        version_tag: Label appended to the run name
            (``training-<version_tag>``).

    Returns:
        The ID of the MLflow run that produced the model.
    """
    # Prepare the data (fixed seed so repeated runs use the same split).
    iris = load_iris()
    X_train, X_test, y_train, y_test = train_test_split(
        iris.data, iris.target, test_size=0.2, random_state=42
    )

    with mlflow.start_run(run_name=f"training-{version_tag}") as run:
        # Train the model.
        model = RandomForestClassifier(n_estimators=100, random_state=42)
        model.fit(X_train, y_train)

        # Evaluate on the held-out split.
        accuracy = accuracy_score(y_test, model.predict(X_test))
        mlflow.log_metric("accuracy", accuracy)

        # Log the model and register it under MODEL_NAME in one step.
        mlflow.sklearn.log_model(
            model,
            "model",
            registered_model_name=MODEL_NAME
        )

        print(f"\n모델 등록 완료: {MODEL_NAME}")
        print(f" Run ID: {run.info.run_id}")
        print(f" Accuracy: {accuracy:.4f}")

    return run.info.run_id
61
62
def get_model_versions(client):
    """Print and return the registered versions of ``MODEL_NAME``.

    Args:
        client: Object exposing ``search_model_versions`` (an
            ``MlflowClient`` in this script).

    Returns:
        Whatever ``search_model_versions`` returned, or an empty list if
        the lookup raised.
    """
    print(f"\n{'='*50}")
    print(f"모델 '{MODEL_NAME}' 버전 목록:")
    print("="*50)

    # Guard only the registry call. The broad catch is a deliberate
    # best-effort for this interactive demo: any lookup failure is reported
    # and treated as "no versions" instead of aborting the script.
    try:
        versions = client.search_model_versions(f"name='{MODEL_NAME}'")
    except Exception as e:
        print(f"모델을 찾을 수 없습니다: {e}")
        return []

    for v in versions:
        print(f"\n버전 {v.version}:")
        print(f" 상태: {v.current_stage}")
        print(f" Run ID: {v.run_id}")
        print(f" 생성일: {v.creation_timestamp}")
        # The description is optional; skip the line when it is empty.
        if v.description:
            print(f" 설명: {v.description}")
    return versions
82
83
84def transition_to_staging(client, version: str):
85 """๋ชจ๋ธ์ Staging์ผ๋ก ์ ํ"""
86 client.transition_model_version_stage(
87 name=MODEL_NAME,
88 version=version,
89 stage="Staging",
90 archive_existing_versions=False
91 )
92 print(f"\n๋ชจ๋ธ v{version}์ Staging์ผ๋ก ์ ํํ์ต๋๋ค.")
93
94
95def transition_to_production(client, version: str):
96 """๋ชจ๋ธ์ Production์ผ๋ก ์ ํ"""
97 client.transition_model_version_stage(
98 name=MODEL_NAME,
99 version=version,
100 stage="Production",
101 archive_existing_versions=True
102 )
103 print(f"\n๋ชจ๋ธ v{version}์ Production์ผ๋ก ์ ํํ์ต๋๋ค.")
104
105
def update_model_description(client, version: str, description: str):
    """Set the description of one version of ``MODEL_NAME``.

    Args:
        client: Object exposing ``update_model_version``.
        version: Identifier of the model version to update.
        description: New description text.
    """
    client.update_model_version(
        name=MODEL_NAME,
        version=version,
        description=description
    )
    print(f"\n모델 v{version} 설명을 업데이트했습니다.")
114
115
def add_model_tag(client, version: str, key: str, value: str):
    """Attach a key/value tag to one version of ``MODEL_NAME``.

    Args:
        client: Object exposing ``set_model_version_tag``.
        version: Identifier of the model version to tag.
        key: Tag name.
        value: Tag value.
    """
    client.set_model_version_tag(
        name=MODEL_NAME,
        version=version,
        key=key,
        value=value
    )
    print(f"\n모델 v{version}에 태그 추가: {key}={value}")
125
126
def load_model_by_stage(stage: str):
    """Load the ``MODEL_NAME`` model registered under the given stage.

    Args:
        stage: Stage name used to build the ``models:/<name>/<stage>`` URI.

    Returns:
        The loaded model, or ``None`` if loading raised.
    """
    # Best-effort by design: a failed load is reported and signalled with
    # None so the interactive demo can carry on.
    try:
        model = mlflow.sklearn.load_model(f"models:/{MODEL_NAME}/{stage}")
        print(f"\n{stage} 모델 로드 성공!")
        return model
    except Exception as e:
        print(f"\n{stage} 모델 로드 실패: {e}")
        return None
136
137
138def demo_workflow(client):
139 """์ ์ฒด ์ํฌํ๋ก์ฐ ๋ฐ๋ชจ"""
140 print("\n" + "="*60)
141 print("MLflow Model Registry ์ํฌํ๋ก์ฐ ๋ฐ๋ชจ")
142 print("="*60)
143
144 # 1. ์ฒซ ๋ฒ์งธ ๋ชจ๋ธ ๋ฑ๋ก
145 print("\n[1] ์ฒซ ๋ฒ์งธ ๋ชจ๋ธ ํ์ต ๋ฐ ๋ฑ๋ก...")
146 train_and_register_model(client, "v1")
147
148 # 2. ๋ฒ์ ์กฐํ
149 versions = get_model_versions(client)
150 if not versions:
151 return
152
153 latest_version = max(v.version for v in versions)
154
155 # 3. ์ค๋ช
์ถ๊ฐ
156 print("\n[2] ๋ชจ๋ธ ์ค๋ช
์ถ๊ฐ...")
157 update_model_description(
158 client, latest_version,
159 "Initial model trained on Iris dataset with Random Forest"
160 )
161
162 # 4. ํ๊ทธ ์ถ๊ฐ
163 print("\n[3] ๋ชจ๋ธ ํ๊ทธ ์ถ๊ฐ...")
164 add_model_tag(client, latest_version, "validated", "true")
165 add_model_tag(client, latest_version, "dataset", "iris")
166
167 # 5. Staging ์ ํ
168 print("\n[4] Staging์ผ๋ก ์ ํ...")
169 transition_to_staging(client, latest_version)
170
171 # 6. ๋ ๋ฒ์งธ ๋ชจ๋ธ ๋ฑ๋ก
172 print("\n[5] ๋ ๋ฒ์งธ ๋ชจ๋ธ ํ์ต ๋ฐ ๋ฑ๋ก (๊ฐ์ ๋ฒ์ )...")
173 train_and_register_model(client, "v2")
174
175 # 7. ๋ฒ์ ์ฌ์กฐํ
176 versions = get_model_versions(client)
177 new_latest = max(v.version for v in versions)
178
179 # 8. ์ ๋ฒ์ ์ Staging์ผ๋ก
180 print("\n[6] ์ ๋ฒ์ ์ Staging์ผ๋ก...")
181 transition_to_staging(client, new_latest)
182
183 # 9. Production ์น๊ฒฉ
184 print("\n[7] Production์ผ๋ก ์น๊ฒฉ...")
185 transition_to_production(client, new_latest)
186
187 # 10. ์ต์ข
์ํ ํ์ธ
188 print("\n[8] ์ต์ข
๋ชจ๋ธ ์ํ:")
189 get_model_versions(client)
190
191 # 11. Production ๋ชจ๋ธ ๋ก๋ ํ
์คํธ
192 print("\n[9] Production ๋ชจ๋ธ ๋ก๋ ํ
์คํธ...")
193 model = load_model_by_stage("Production")
194 if model:
195 # ๊ฐ๋จํ ์์ธก ํ
์คํธ
196 iris = load_iris()
197 sample = iris.data[:3]
198 predictions = model.predict(sample)
199 print(f" ์ํ ์์ธก ๊ฒฐ๊ณผ: {predictions}")
200 print(f" ์ค์ ๋ ์ด๋ธ: {iris.target[:3]}")
201
202
def main():
    """Show the interactive menu and run the selected action.

    Options: ``1`` trains and registers a new model, ``2`` lists the
    registered versions, ``3`` runs the full workflow demo. Any other
    input prints an error message.
    """
    client = setup()

    print("\nMLflow Model Registry 예제")
    print("="*50)
    print("\n옵션:")
    print("1. 새 모델 학습 및 등록")
    print("2. 등록된 모델 조회")
    print("3. 전체 워크플로우 데모")

    choice = input("\n선택 (1/2/3): ").strip()

    if choice == "1":
        train_and_register_model(client, "manual")
    elif choice == "2":
        get_model_versions(client)
    elif choice == "3":
        demo_workflow(client)
    else:
        print("잘못된 선택입니다.")
224
225
# Run the interactive menu only when this file is executed as a script.
if __name__ == "__main__":
    main()