# drift_detection.py — data drift detection example (statistical tests + optional Evidently AI)
  1"""
  2Drift Detection Example
  3=======================
  4
  5Evidently AI๋ฅผ ์‚ฌ์šฉํ•œ ๋ฐ์ดํ„ฐ ๋“œ๋ฆฌํ”„ํŠธ ๊ฐ์ง€ ์˜ˆ์ œ์ž…๋‹ˆ๋‹ค.
  6
  7์‹คํ–‰ ๋ฐฉ๋ฒ•:
  8    pip install evidently pandas numpy scikit-learn
  9    python drift_detection.py
 10"""
 11
 12import pandas as pd
 13import numpy as np
 14from sklearn.datasets import make_classification
 15from scipy import stats
 16from datetime import datetime, timedelta
 17from typing import Dict, Any, Tuple
 18import warnings
 19warnings.filterwarnings("ignore")
 20
 21# Evidently ์ž„ํฌํŠธ (์„ ํƒ์ )
 22try:
 23    from evidently import ColumnMapping
 24    from evidently.report import Report
 25    from evidently.metric_preset import DataDriftPreset
 26    from evidently.metrics import DatasetDriftMetric, ColumnDriftMetric
 27    from evidently.test_suite import TestSuite
 28    from evidently.test_preset import DataDriftTestPreset
 29    EVIDENTLY_AVAILABLE = True
 30except ImportError:
 31    EVIDENTLY_AVAILABLE = False
 32    print("Evidently๊ฐ€ ์„ค์น˜๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ๊ธฐ๋ณธ ํ†ต๊ณ„ ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.")
 33
 34
 35# ============================================================
 36# ๊ธฐ๋ณธ ํ†ต๊ณ„์  ๋“œ๋ฆฌํ”„ํŠธ ๊ฐ์ง€
 37# ============================================================
 38
 39class StatisticalDriftDetector:
 40    """ํ†ต๊ณ„์  ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•œ ๋“œ๋ฆฌํ”„ํŠธ ๊ฐ์ง€"""
 41
 42    def __init__(self, significance_level: float = 0.05):
 43        self.significance_level = significance_level
 44
 45    def ks_test(self, reference: np.ndarray, current: np.ndarray) -> Dict[str, Any]:
 46        """Kolmogorov-Smirnov ๊ฒ€์ •"""
 47        statistic, p_value = stats.ks_2samp(reference, current)
 48        return {
 49            "test": "ks",
 50            "statistic": float(statistic),
 51            "p_value": float(p_value),
 52            "drift_detected": p_value < self.significance_level
 53        }
 54
 55    def psi(self, reference: np.ndarray, current: np.ndarray, n_bins: int = 10) -> float:
 56        """Population Stability Index"""
 57        # ํžˆ์Šคํ† ๊ทธ๋žจ ์ƒ์„ฑ
 58        min_val = min(reference.min(), current.min())
 59        max_val = max(reference.max(), current.max())
 60        bins = np.linspace(min_val, max_val, n_bins + 1)
 61
 62        ref_counts, _ = np.histogram(reference, bins=bins)
 63        cur_counts, _ = np.histogram(current, bins=bins)
 64
 65        # ๋น„์œจ๋กœ ๋ณ€ํ™˜
 66        ref_pct = (ref_counts + 1) / (len(reference) + n_bins)
 67        cur_pct = (cur_counts + 1) / (len(current) + n_bins)
 68
 69        # PSI ๊ณ„์‚ฐ
 70        psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
 71        return float(psi)
 72
 73    def wasserstein_distance(self, reference: np.ndarray, current: np.ndarray) -> float:
 74        """Wasserstein ๊ฑฐ๋ฆฌ"""
 75        return float(stats.wasserstein_distance(reference, current))
 76
 77    def detect_column_drift(
 78        self,
 79        reference: pd.DataFrame,
 80        current: pd.DataFrame,
 81        column: str
 82    ) -> Dict[str, Any]:
 83        """๋‹จ์ผ ์ปฌ๋Ÿผ ๋“œ๋ฆฌํ”„ํŠธ ๊ฐ์ง€"""
 84        ref_col = reference[column].dropna().values
 85        cur_col = current[column].dropna().values
 86
 87        ks_result = self.ks_test(ref_col, cur_col)
 88        psi_value = self.psi(ref_col, cur_col)
 89        wasserstein = self.wasserstein_distance(ref_col, cur_col)
 90
 91        # PSI ํ•ด์„
 92        if psi_value < 0.1:
 93            psi_status = "no_drift"
 94        elif psi_value < 0.2:
 95            psi_status = "slight_drift"
 96        else:
 97            psi_status = "significant_drift"
 98
 99        return {
100            "column": column,
101            "ks_test": ks_result,
102            "psi": {
103                "value": psi_value,
104                "status": psi_status
105            },
106            "wasserstein_distance": wasserstein,
107            "drift_detected": ks_result["drift_detected"] or psi_value >= 0.2
108        }
109
110    def detect_dataset_drift(
111        self,
112        reference: pd.DataFrame,
113        current: pd.DataFrame,
114        numerical_columns: list
115    ) -> Dict[str, Any]:
116        """๋ฐ์ดํ„ฐ์…‹ ์ „์ฒด ๋“œ๋ฆฌํ”„ํŠธ ๊ฐ์ง€"""
117        results = {
118            "timestamp": datetime.now().isoformat(),
119            "columns": {},
120            "summary": {
121                "total_columns": len(numerical_columns),
122                "drifted_columns": 0,
123                "drift_detected": False
124            }
125        }
126
127        for col in numerical_columns:
128            if col in reference.columns and col in current.columns:
129                col_result = self.detect_column_drift(reference, current, col)
130                results["columns"][col] = col_result
131                if col_result["drift_detected"]:
132                    results["summary"]["drifted_columns"] += 1
133
134        drift_share = results["summary"]["drifted_columns"] / results["summary"]["total_columns"]
135        results["summary"]["drift_share"] = drift_share
136        results["summary"]["drift_detected"] = drift_share > 0.5
137
138        return results
139
140
141# ============================================================
142# Evidently ๊ธฐ๋ฐ˜ ๋“œ๋ฆฌํ”„ํŠธ ๊ฐ์ง€
143# ============================================================
144
class EvidentlyDriftDetector:
    """Drift detection backed by the Evidently AI library."""

    def __init__(self):
        # Fail fast when the optional dependency was not importable.
        if not EVIDENTLY_AVAILABLE:
            raise ImportError("Evidently is not installed")

    def create_report(
        self,
        reference: pd.DataFrame,
        current: pd.DataFrame,
        column_mapping: ColumnMapping = None
    ) -> Report:
        """Build and execute a drift report comparing current vs. reference."""
        metrics = [
            DatasetDriftMetric(),
            DataDriftPreset(),
        ]
        report = Report(metrics=metrics)
        report.run(
            reference_data=reference,
            current_data=current,
            column_mapping=column_mapping,
        )
        return report

    def run_tests(
        self,
        reference: pd.DataFrame,
        current: pd.DataFrame,
        column_mapping: ColumnMapping = None
    ) -> TestSuite:
        """Execute the drift test-suite preset and return the suite."""
        suite = TestSuite(tests=[DataDriftTestPreset()])
        suite.run(
            reference_data=reference,
            current_data=current,
            column_mapping=column_mapping,
        )
        return suite

    def get_drift_summary(self, report: Report) -> Dict[str, Any]:
        """Extract the DatasetDriftMetric summary from a finished report."""
        payload = report.as_dict()

        # Scan the serialized metrics for the dataset-level drift entry.
        candidates = (
            entry.get("result", {})
            for entry in payload.get("metrics", [])
            if "DatasetDriftMetric" in str(entry.get("metric", ""))
        )
        drift_result = next(candidates, None)
        if drift_result is None:
            return {"error": "Could not extract drift summary"}

        return {
            "dataset_drift": drift_result.get("dataset_drift", False),
            "drift_share": drift_result.get("drift_share", 0),
            "number_of_columns": drift_result.get("number_of_columns", 0),
            "number_of_drifted_columns": drift_result.get("number_of_drifted_columns", 0),
        }
207
208
209# ============================================================
210# ๋“œ๋ฆฌํ”„ํŠธ ๋ชจ๋‹ˆํ„ฐ๋ง ์‹œ์Šคํ…œ
211# ============================================================
212
class DriftMonitor:
    """Drift monitoring system.

    Compares each incoming batch against a fixed reference dataset using
    StatisticalDriftDetector, records a summary history, and emits alerts.
    """

    def __init__(
        self,
        reference_data: pd.DataFrame,
        numerical_columns: list,
        alert_threshold: float = 0.3
    ):
        """
        Args:
            reference_data: Baseline dataset every batch is compared against.
            numerical_columns: Numeric column names to monitor.
            alert_threshold: Share of drifted columns at which a critical
                alert fires. Fix: this parameter was previously stored but
                never used; it now gates the critical alert.
        """
        self.reference_data = reference_data
        self.numerical_columns = numerical_columns
        self.alert_threshold = alert_threshold
        self.detector = StatisticalDriftDetector()
        self.history = []  # one compact summary entry per check() call

    def check(self, current_data: pd.DataFrame) -> Dict[str, Any]:
        """Run drift detection on a new batch and return the annotated result."""
        result = self.detector.detect_dataset_drift(
            self.reference_data,
            current_data,
            self.numerical_columns
        )

        # Record a compact summary in the monitoring history.
        self.history.append({
            "timestamp": result["timestamp"],
            "drift_share": result["summary"]["drift_share"],
            "drift_detected": result["summary"]["drift_detected"]
        })

        # Attach generated alerts to the result.
        result["alerts"] = self._generate_alerts(result)

        return result

    def _generate_alerts(self, result: Dict) -> list:
        """Build alert dicts (level, message) from a detection result."""
        alerts = []
        summary = result["summary"]

        # Critical when the detector flags dataset-level drift OR the share
        # of drifted columns reaches the configured alert threshold.
        if summary["drift_detected"] or summary["drift_share"] >= self.alert_threshold:
            alerts.append({
                "level": "critical",
                "message": f"Dataset drift detected: {summary['drift_share']:.1%} of columns drifted"
            })

        # Per-column warnings for strong PSI drift (>= 0.25).
        for col, col_result in result["columns"].items():
            if col_result["drift_detected"]:
                psi = col_result["psi"]["value"]
                if psi >= 0.25:
                    alerts.append({
                        "level": "warning",
                        "message": f"Significant drift in '{col}': PSI={psi:.3f}"
                    })

        return alerts

    def get_trend(self, window_size: int = 10) -> Dict[str, Any]:
        """Summarize drift over the most recent window of checks.

        Returns a placeholder message dict when fewer than two checks
        have been recorded.
        """
        if len(self.history) < 2:
            return {"message": "Not enough data for trend analysis"}

        recent = self.history[-window_size:]
        drift_shares = [h["drift_share"] for h in recent]

        return {
            "window_size": len(recent),
            "avg_drift_share": np.mean(drift_shares),
            "max_drift_share": max(drift_shares),
            "drift_count": sum(1 for h in recent if h["drift_detected"]),
            # Crude direction check: newest share vs. oldest in the window.
            "trend": "increasing" if len(drift_shares) > 1 and drift_shares[-1] > drift_shares[0] else "stable"
        }
284
285
286# ============================================================
287# ์˜ˆ์ œ ์‹คํ–‰
288# ============================================================
289
def generate_sample_data(n_samples: int = 1000, drift: bool = False) -> pd.DataFrame:
    """Create a synthetic 4-feature dataset.

    With drift=True, feature_1 and feature_3 are redrawn from shifted /
    rescaled distributions to simulate data drift.
    """
    # Distinct fixed seeds keep both variants reproducible run-to-run.
    np.random.seed(123 if drift else 42)

    frame = pd.DataFrame({
        "feature_1": np.random.normal(0, 1, n_samples),
        "feature_2": np.random.normal(5, 2, n_samples),
        "feature_3": np.random.exponential(2, n_samples),
        "feature_4": np.random.uniform(0, 10, n_samples),
    })

    if drift:
        # Overwrite two features: mean/variance shift and rate change.
        frame["feature_1"] = np.random.normal(0.5, 1.2, n_samples)
        frame["feature_3"] = np.random.exponential(3, n_samples)

    return frame
307
308
def main():
    """Run the end-to-end drift-detection demo.

    Steps: generate synthetic reference/current datasets, run the
    statistical detector on both, optionally produce an Evidently HTML
    report, then simulate a monitoring loop with gradually growing drift.
    All user-facing output strings are left exactly as authored.
    """
    print("="*60)
    print("๋“œ๋ฆฌํ”„ํŠธ ๊ฐ์ง€ ์˜ˆ์ œ")
    print("="*60)

    # 1. Generate the reference batch and two current batches
    #    (one from the same process, one with injected drift).
    print("\n[1] ๋ฐ์ดํ„ฐ ์ƒ์„ฑ...")
    reference_data = generate_sample_data(1000, drift=False)
    current_data_no_drift = generate_sample_data(500, drift=False)
    current_data_with_drift = generate_sample_data(500, drift=True)

    print(f"  ์ฐธ์กฐ ๋ฐ์ดํ„ฐ: {len(reference_data)} ์ƒ˜ํ”Œ")
    print(f"  ํ˜„์žฌ ๋ฐ์ดํ„ฐ (๋“œ๋ฆฌํ”„ํŠธ ์—†์Œ): {len(current_data_no_drift)} ์ƒ˜ํ”Œ")
    print(f"  ํ˜„์žฌ ๋ฐ์ดํ„ฐ (๋“œ๋ฆฌํ”„ํŠธ ์žˆ์Œ): {len(current_data_with_drift)} ์ƒ˜ํ”Œ")

    # 2. Baseline statistical drift detection (KS / PSI / Wasserstein)
    print("\n[2] ํ†ต๊ณ„์  ๋“œ๋ฆฌํ”„ํŠธ ๊ฐ์ง€...")
    detector = StatisticalDriftDetector()
    numerical_cols = ["feature_1", "feature_2", "feature_3", "feature_4"]

    # Batch drawn from the same process — expect no drift.
    print("\n  --- ๋“œ๋ฆฌํ”„ํŠธ ์—†๋Š” ๋ฐ์ดํ„ฐ ---")
    result_no_drift = detector.detect_dataset_drift(
        reference_data, current_data_no_drift, numerical_cols
    )
    print(f"  ๋“œ๋ฆฌํ”„ํŠธ ๊ฐ์ง€: {result_no_drift['summary']['drift_detected']}")
    print(f"  ๋“œ๋ฆฌํ”„ํŠธ ๋น„์œจ: {result_no_drift['summary']['drift_share']:.1%}")

    # Batch with injected drift — expect detection.
    print("\n  --- ๋“œ๋ฆฌํ”„ํŠธ ์žˆ๋Š” ๋ฐ์ดํ„ฐ ---")
    result_with_drift = detector.detect_dataset_drift(
        reference_data, current_data_with_drift, numerical_cols
    )
    print(f"  ๋“œ๋ฆฌํ”„ํŠธ ๊ฐ์ง€: {result_with_drift['summary']['drift_detected']}")
    print(f"  ๋“œ๋ฆฌํ”„ํŠธ ๋น„์œจ: {result_with_drift['summary']['drift_share']:.1%}")

    # Per-column breakdown of the drifted batch.
    print("\n  ์ปฌ๋Ÿผ๋ณ„ ์ƒ์„ธ:")
    for col, col_result in result_with_drift["columns"].items():
        drift_status = "DRIFT" if col_result["drift_detected"] else "OK"
        psi = col_result["psi"]["value"]
        print(f"    {col}: PSI={psi:.4f} [{drift_status}]")

    # 3. Evidently-based detection (only when the optional library imported).
    if EVIDENTLY_AVAILABLE:
        print("\n[3] Evidently ๋“œ๋ฆฌํ”„ํŠธ ๊ฐ์ง€...")
        evidently_detector = EvidentlyDriftDetector()

        report = evidently_detector.create_report(
            reference_data, current_data_with_drift
        )

        summary = evidently_detector.get_drift_summary(report)
        print(f"  Dataset Drift: {summary.get('dataset_drift', 'N/A')}")
        print(f"  Drift Share: {summary.get('drift_share', 0):.1%}")
        print(f"  Drifted Columns: {summary.get('number_of_drifted_columns', 0)}/{summary.get('number_of_columns', 0)}")

        # Persist the interactive HTML report to the working directory.
        report.save_html("drift_report.html")
        print("\n  HTML ๋ฆฌํฌํŠธ ์ €์žฅ: drift_report.html")
    else:
        print("\n[3] Evidently ์„ค์น˜ ํ•„์š” (pip install evidently)")

    # 4. Monitoring-system simulation over several "time points".
    print("\n[4] ๋ชจ๋‹ˆํ„ฐ๋ง ์‹œ์Šคํ…œ ์‹œ๋ฎฌ๋ ˆ์ด์…˜...")
    monitor = DriftMonitor(
        reference_data=reference_data,
        numerical_columns=numerical_cols
    )

    # Gradually increase drift: shift feature_1, rescale feature_3.
    for i in range(5):
        drift_factor = i * 0.1
        test_data = reference_data.copy()
        test_data["feature_1"] = test_data["feature_1"] + drift_factor
        test_data["feature_3"] = test_data["feature_3"] * (1 + drift_factor)

        # NOTE(review): .sample(500) is unseeded, so alert output varies
        # slightly between runs — presumably intentional for the demo.
        result = monitor.check(test_data.sample(500))
        print(f"\n  ์‹œ์  {i+1}:")
        print(f"    ๋“œ๋ฆฌํ”„ํŠธ ๊ฐ์ง€: {result['summary']['drift_detected']}")
        print(f"    ๋“œ๋ฆฌํ”„ํŠธ ๋น„์œจ: {result['summary']['drift_share']:.1%}")
        if result["alerts"]:
            for alert in result["alerts"]:
                print(f"    [{alert['level'].upper()}] {alert['message']}")

    # Trend analysis over the recorded monitoring history.
    trend = monitor.get_trend()
    print(f"\n  ํŠธ๋ Œ๋“œ ๋ถ„์„:")
    print(f"    ํ‰๊ท  ๋“œ๋ฆฌํ”„ํŠธ ๋น„์œจ: {trend['avg_drift_share']:.1%}")
    print(f"    ๋“œ๋ฆฌํ”„ํŠธ ๊ฐ์ง€ ํšŸ์ˆ˜: {trend['drift_count']}")
    print(f"    ํŠธ๋ Œ๋“œ: {trend['trend']}")

    print("\n" + "="*60)
    print("์˜ˆ์ œ ์™„๋ฃŒ!")


if __name__ == "__main__":
    main()