onnx_inference.py

Download
python 553 lines 17.7 KB
  1#!/usr/bin/env python3
  2"""
  3ONNX Runtime ๊ธฐ๋ฐ˜ Edge AI ์ถ”๋ก 
  4์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ ๋ฐ ๊ฐ์ฒด ๊ฒ€์ถœ ์˜ˆ์ œ
  5
  6์ฐธ๊ณ : content/ko/IoT_Embedded/09_Edge_AI_ONNX.md
  7"""
  8
  9import numpy as np
 10import time
 11from typing import Optional, Tuple, List, Dict
 12import os
 13
 14# ONNX Runtime ์„ค์น˜ ํ™•์ธ
 15try:
 16    import onnxruntime as ort
 17    HAS_ONNX = True
 18except ImportError:
 19    HAS_ONNX = False
 20    print("๊ฒฝ๊ณ : onnxruntime์ด ์„ค์น˜๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
 21    print("์„ค์น˜: pip install onnxruntime")
 22
 23# OpenCV ์„ค์น˜ ํ™•์ธ
 24try:
 25    import cv2
 26    HAS_CV2 = True
 27except ImportError:
 28    HAS_CV2 = False
 29    print("๊ฒฝ๊ณ : opencv-python์ด ์„ค์น˜๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
 30    print("์„ค์น˜: pip install opencv-python")
 31
 32# PIL ์„ค์น˜ ํ™•์ธ
 33try:
 34    from PIL import Image
 35    HAS_PIL = True
 36except ImportError:
 37    HAS_PIL = False
 38    print("๊ฒฝ๊ณ : Pillow๊ฐ€ ์„ค์น˜๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
 39    print("์„ค์น˜: pip install Pillow")
 40
 41
 42# === ONNX ๋ชจ๋ธ ๋ž˜ํผ ===
 43
 44class ONNXModel:
 45    """ONNX ๋ชจ๋ธ ๊ธฐ๋ณธ ๋ž˜ํผ"""
 46
 47    def __init__(self, model_path: str, providers: Optional[List[str]] = None):
 48        if not HAS_ONNX:
 49            raise ImportError("onnxruntime์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค")
 50
 51        if not os.path.exists(model_path):
 52            raise FileNotFoundError(f"๋ชจ๋ธ ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {model_path}")
 53
 54        if providers is None:
 55            # ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ํ”„๋กœ๋ฐ”์ด๋” ์ž๋™ ์„ ํƒ
 56            available = ort.get_available_providers()
 57            if 'CUDAExecutionProvider' in available:
 58                providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
 59            else:
 60                providers = ['CPUExecutionProvider']
 61
 62        # ์„ธ์…˜ ์˜ต์…˜ ์„ค์ •
 63        sess_options = ort.SessionOptions()
 64        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
 65        sess_options.intra_op_num_threads = 4  # CPU ์Šค๋ ˆ๋“œ ์ˆ˜
 66
 67        # ์„ธ์…˜ ์ƒ์„ฑ
 68        self.session = ort.InferenceSession(
 69            model_path,
 70            sess_options=sess_options,
 71            providers=providers
 72        )
 73
 74        # ์ž…์ถœ๋ ฅ ์ •๋ณด
 75        self.input_name = self.session.get_inputs()[0].name
 76        self.input_shape = self.session.get_inputs()[0].shape
 77        self.input_type = self.session.get_inputs()[0].type
 78        self.output_name = self.session.get_outputs()[0].name
 79
 80        print(f"๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ: {model_path}")
 81        print(f"  ํ”„๋กœ๋ฐ”์ด๋”: {self.session.get_providers()}")
 82        print(f"  ์ž…๋ ฅ: {self.input_name} {self.input_shape}")
 83        print(f"  ์ถœ๋ ฅ: {self.output_name}")
 84
 85    def get_input_shape(self) -> list:
 86        """์ž…๋ ฅ ํ˜•ํƒœ ๋ฐ˜ํ™˜"""
 87        return self.input_shape
 88
 89    def predict(self, input_data: np.ndarray) -> np.ndarray:
 90        """์ถ”๋ก  ์ˆ˜ํ–‰"""
 91        outputs = self.session.run(
 92            [self.output_name],
 93            {self.input_name: input_data}
 94        )
 95        return outputs[0]
 96
 97    def benchmark(self, num_iterations: int = 100) -> Dict[str, float]:
 98        """์„ฑ๋Šฅ ๋ฒค์น˜๋งˆํฌ"""
 99        # ๋”๋ฏธ ์ž…๋ ฅ ์ƒ์„ฑ
100        dummy_shape = [1 if x == 'batch' or x == 'N' or x is None else x
101                      for x in self.input_shape]
102        dummy_input = np.random.randn(*dummy_shape).astype(np.float32)
103
104        # ์›Œ๋ฐ์—…
105        for _ in range(10):
106            self.predict(dummy_input)
107
108        # ์ธก์ •
109        times = []
110        for _ in range(num_iterations):
111            start = time.perf_counter()
112            self.predict(dummy_input)
113            elapsed = (time.perf_counter() - start) * 1000  # ms
114            times.append(elapsed)
115
116        times = np.array(times)
117
118        results = {
119            "mean_ms": float(np.mean(times)),
120            "std_ms": float(np.std(times)),
121            "min_ms": float(np.min(times)),
122            "max_ms": float(np.max(times)),
123            "fps": 1000.0 / np.mean(times)
124        }
125
126        return results
127
128
# === Image classification model ===

class ImageClassifier(ONNXModel):
    """ONNX image classifier with ImageNet-style preprocessing.

    Assumes the model's first input is laid out as (N, C, H, W): H and W are
    read from axes 2 and 3 of its shape.
    """

    # ImageNet class names: only the first 10 as an example (the full label
    # set has 1000 classes). Indices past the end are reported as "class_<idx>".
    IMAGENET_CLASSES = [
        'tench', 'goldfish', 'great_white_shark', 'tiger_shark',
        'hammerhead', 'electric_ray', 'stingray', 'cock', 'hen', 'ostrich'
    ]

    # ImageNet per-channel mean/std scaled to 0-255 pixel values. Kept as
    # float32 so the normalised tensor stays float32 (a float64 tensor is
    # rejected by ONNX Runtime for a float32 model input).
    _MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32) * 255
    _STD = np.array([0.229, 0.224, 0.225], dtype=np.float32) * 255

    def __init__(self, model_path: str, labels_path: Optional[str] = None):
        """Load the model and, optionally, a labels file.

        Args:
            model_path: path to the ``.onnx`` classifier.
            labels_path: optional UTF-8 text file with one label per line.
                When missing or not found, IMAGENET_CLASSES is used.
        """
        super().__init__(model_path)

        if labels_path and os.path.exists(labels_path):
            with open(labels_path, 'r', encoding='utf-8') as f:
                self.labels = [line.strip() for line in f]
        else:
            self.labels = self.IMAGENET_CLASSES

        # Spatial size from the (N, C, H, W) input shape; symbolic dims
        # (dynamic-shape exports) or a short shape fall back to 224.
        self.input_height = self._static_dim(self.input_shape, 2, 224)
        self.input_width = self._static_dim(self.input_shape, 3, 224)

    @staticmethod
    def _static_dim(shape, axis: int, default: int) -> int:
        """Return ``shape[axis]`` if it is a positive int, else ``default``."""
        if axis < len(shape):
            dim = shape[axis]
            if isinstance(dim, int) and dim > 0:
                return dim
        return default

    def preprocess_image(self, image_path: str) -> np.ndarray:
        """Load an image file and turn it into a (1, 3, H, W) float32 tensor.

        Steps: RGB conversion, resize to the model input size, ImageNet
        mean/std normalisation, HWC -> CHW, add batch axis.

        Raises:
            ImportError: if Pillow is not installed.
        """
        if not HAS_PIL:
            raise ImportError("Pillow๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค")

        # Context manager closes the underlying file handle.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        # PIL's resize takes (width, height).
        image = image.resize((self.input_width, self.input_height))

        img_array = np.asarray(image, dtype=np.float32)
        img_array = (img_array - self._MEAN) / self._STD

        # HWC -> CHW, then prepend the batch axis.
        img_array = img_array.transpose(2, 0, 1)[np.newaxis, ...]
        return np.ascontiguousarray(img_array, dtype=np.float32)

    def classify(self, image_path: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Classify an image and return the ``top_k`` (label, probability) pairs.

        Pairs are ordered by descending probability. ``top_k`` <= 0 returns
        an empty list; values larger than the class count are clamped.
        """
        input_data = self.preprocess_image(image_path)

        start = time.perf_counter()
        output = self.predict(input_data)
        inference_time = (time.perf_counter() - start) * 1000

        # Softmax over the logits of the first batch item.
        probs = self._softmax(output[0])

        # Clamp k: argsort(...)[-0:] would otherwise return every class.
        k = max(0, min(top_k, probs.size))
        top_indices = np.argsort(probs)[::-1][:k]

        results = [
            (self.labels[idx] if idx < len(self.labels) else f"class_{idx}",
             float(probs[idx]))
            for idx in top_indices
        ]

        print(f"์ถ”๋ก  ์‹œ๊ฐ„: {inference_time:.2f}ms")

        return results

    @staticmethod
    def _softmax(x: np.ndarray) -> np.ndarray:
        """Numerically stable softmax (the max is subtracted before exp)."""
        exp_x = np.exp(x - np.max(x))
        return exp_x / exp_x.sum()


# === Object detection model (YOLO) ===

class YOLODetector:
    """YOLO object detector running an ONNX model on the CPU provider.

    ``postprocess`` assumes a YOLOv5-style output: ``output[0]`` is a table
    with one row per candidate laid out as
    ``[cx, cy, w, h, objectness, class_0_prob, ..., class_K_prob]`` in
    model-input pixels. NOTE(review): other export layouts (e.g. YOLOv8's
    transposed output without an objectness column) are not handled --
    confirm against the model in use.

    When the model file is missing the detector runs in simulation mode and
    ``detect`` returns random boxes instead.
    """

    # The 80 COCO class names, indexed by class id.
    COCO_CLASSES = [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
        'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
        'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]

    def __init__(self, model_path: str, conf_threshold: float = 0.5,
                 iou_threshold: float = 0.45):
        """Load ``model_path`` or fall back to simulation mode if it is missing.

        Args:
            model_path: path to the YOLO ``.onnx`` model.
            conf_threshold: minimum objectness / class probability / final
                score for a detection to be kept.
            iou_threshold: IoU above which NMS suppresses overlapping boxes.

        Raises:
            ImportError: if onnxruntime is not installed (even in simulation mode).
        """
        if not HAS_ONNX:
            raise ImportError("onnxruntime์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค")

        # Set before the simulation-mode early return so they always exist.
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold

        if not os.path.exists(model_path):
            # No model file: run in simulation mode.
            print(f"๊ฒฝ๊ณ : ๋ชจ๋ธ ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {model_path}")
            print("์‹œ๋ฎฌ๋ ˆ์ด์…˜ ๋ชจ๋“œ๋กœ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค.")
            self.simulation_mode = True
            self.input_height = 640
            self.input_width = 640
            return

        self.simulation_mode = False

        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        self.session = ort.InferenceSession(
            model_path,
            sess_options=sess_options,
            providers=['CPUExecutionProvider']
        )

        input_info = self.session.get_inputs()[0]
        self.input_name = input_info.name
        self.input_shape = input_info.shape
        # (N, C, H, W) input; symbolic dims (dynamic-shape exports) fall back to 640.
        self.input_height = self._static_dim(self.input_shape, 2, 640)
        self.input_width = self._static_dim(self.input_shape, 3, 640)

        print("YOLO ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ")
        print(f"  ์ž…๋ ฅ ํฌ๊ธฐ: {self.input_width}x{self.input_height}")

    @staticmethod
    def _static_dim(shape, axis: int, default: int) -> int:
        """Return ``shape[axis]`` if it is a positive int, else ``default``."""
        if axis < len(shape):
            dim = shape[axis]
            if isinstance(dim, int) and dim > 0:
                return dim
        return default

    def preprocess(self, image: np.ndarray) -> Tuple[np.ndarray, Tuple[float, float]]:
        """Resize a BGR image (0-255 values) into a (1, 3, H, W) float32 tensor in [0, 1].

        The image is stretched to the model input size; aspect ratio is not
        preserved.

        Returns:
            ``(input_tensor, (x_scale, y_scale))``. The scales map
            model-input pixel coordinates back to the original image.
        """
        orig_height, orig_width = image.shape[:2]

        resized = cv2.resize(image, (self.input_width, self.input_height))

        # BGR -> RGB and HWC -> CHW. The flipped/transposed view is not
        # C-contiguous, so build a contiguous float32 copy before scaling.
        chw = resized[:, :, ::-1].transpose(2, 0, 1)
        input_data = np.ascontiguousarray(chw, dtype=np.float32) / 255.0

        input_data = np.expand_dims(input_data, axis=0)

        scale = (orig_width / self.input_width, orig_height / self.input_height)

        return input_data, scale

    def detect(self, image: np.ndarray) -> List[Dict]:
        """Detect objects in ``image``.

        Returns a list of dicts with 'box' ([x1, y1, x2, y2] in image pixels),
        'score', 'class_id' and 'class_name'. In simulation mode the
        detections are random.

        Raises:
            ImportError: if OpenCV is missing (real-model mode only).
        """
        if self.simulation_mode:
            print("์‹œ๋ฎฌ๋ ˆ์ด์…˜ ๋ชจ๋“œ: ๋žœ๋ค ๊ฒ€์ถœ ๊ฒฐ๊ณผ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.")
            return self._simulate_detection(image)

        if not HAS_CV2:
            raise ImportError("opencv-python์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค")

        input_data, scale = self.preprocess(image)

        start = time.perf_counter()
        outputs = self.session.run(None, {self.input_name: input_data})
        inference_time = (time.perf_counter() - start) * 1000

        detections = self.postprocess(outputs[0], scale)

        print(f"์ถ”๋ก  ์‹œ๊ฐ„: {inference_time:.2f}ms")
        print(f"๊ฒ€์ถœ๋œ ๊ฐ์ฒด: {len(detections)}๊ฐœ")

        return detections

    def postprocess(self, output: np.ndarray, scale: Tuple[float, float]) -> List[Dict]:
        """Threshold, rescale and NMS-filter raw model output.

        Args:
            output: raw output tensor; ``output[0]`` is the prediction table
                (layout described in the class docstring).
            scale: ``(x_scale, y_scale)`` from model-input to image pixels.

        Returns:
            Detection dicts (see ``detect``); empty when nothing survives or
            OpenCV is unavailable.
        """
        if not HAS_CV2:
            return []

        boxes, scores, class_ids = self._decode_predictions(output[0], scale)
        if not boxes:
            return []

        # cv2.dnn.NMSBoxes takes rectangles as (x, y, width, height); passing
        # (x1, y1, x2, y2) would make it compute IoU on the wrong boxes.
        nms_boxes = [[x1, y1, x2 - x1, y2 - y1] for x1, y1, x2, y2 in boxes]
        indices = cv2.dnn.NMSBoxes(
            nms_boxes, scores, self.conf_threshold, self.iou_threshold
        )

        # Depending on the OpenCV version, indices is a flat array, an (N, 1)
        # array, or an empty tuple; flatten all of them the same way.
        results = []
        for i in np.asarray(indices, dtype=int).reshape(-1):
            idx = int(i)
            results.append({
                'box': boxes[idx],
                'score': scores[idx],
                'class_id': class_ids[idx],
                'class_name': self._class_name(class_ids[idx]),
            })

        return results

    def _decode_predictions(
        self, predictions: np.ndarray, scale: Tuple[float, float]
    ) -> Tuple[List[List[int]], List[float], List[int]]:
        """Apply confidence thresholds and rescale boxes (the pre-NMS step).

        A row is kept when its objectness AND its best class probability both
        exceed ``conf_threshold``. Its score is objectness * class probability.
        Boxes are ``[x1, y1, x2, y2]`` ints in original-image pixels,
        truncated toward zero.
        """
        predictions = np.asarray(predictions)

        # Objectness filter first: it discards most rows cheaply.
        candidates = predictions[predictions[:, 4] > self.conf_threshold]
        if candidates.shape[0] == 0:
            return [], [], []

        class_probs = candidates[:, 5:]
        class_ids = np.argmax(class_probs, axis=1)
        class_scores = class_probs[np.arange(len(candidates)), class_ids]

        keep = class_scores > self.conf_threshold
        candidates = candidates[keep]
        class_ids = class_ids[keep]
        class_scores = class_scores[keep]

        cx, cy, w, h = (candidates[:, j] for j in range(4))
        sx, sy = scale
        xyxy = np.stack([
            (cx - w / 2) * sx,
            (cy - h / 2) * sy,
            (cx + w / 2) * sx,
            (cy + h / 2) * sy,
        ], axis=1)

        boxes = xyxy.astype(int).tolist()  # astype(int) truncates like int()
        scores = (candidates[:, 4] * class_scores).astype(float).tolist()
        return boxes, scores, class_ids.astype(int).tolist()

    def _class_name(self, class_id: int) -> str:
        """COCO name for ``class_id``, or ``"class_<id>"`` outside the known classes."""
        if 0 <= class_id < len(self.COCO_CLASSES):
            return self.COCO_CLASSES[class_id]
        return f"class_{class_id}"

    def _simulate_detection(self, image: np.ndarray) -> List[Dict]:
        """Generate 1-4 random detections that fit inside ``image``.

        Used only in simulation mode. Boxes aim to be at least 50 px wide and
        tall, shrinking to fit small images. Returns [] for images smaller
        than 2x2 px, where no box with x1 < x2 and y1 < y2 fits.
        """
        height, width = image.shape[:2]
        if width < 2 or height < 2:
            return []

        detections = []
        for _ in range(np.random.randint(1, 5)):
            x1 = np.random.randint(0, width // 2)
            y1 = np.random.randint(0, height // 2)
            # Clamp the lower bound so randint's low < high holds on small images.
            x2 = np.random.randint(min(x1 + 50, width - 1), width)
            y2 = np.random.randint(min(y1 + 50, height - 1), height)

            class_id = int(np.random.randint(0, len(self.COCO_CLASSES)))

            detections.append({
                'box': [int(x1), int(y1), int(x2), int(y2)],
                'score': float(np.random.uniform(0.5, 0.95)),
                'class_id': class_id,
                'class_name': self._class_name(class_id),
            })

        return detections

    def draw_detections(self, image: np.ndarray, detections: List[Dict]) -> np.ndarray:
        """Return a copy of ``image`` with boxes and "name: score" labels drawn.

        Returns ``image`` unchanged (not a copy) when OpenCV is unavailable.
        """
        if not HAS_CV2:
            print("๊ฒฝ๊ณ : opencv-python์ด ์—†์–ด ์‹œ๊ฐํ™”๋ฅผ ๊ฑด๋„ˆ๋œ๋‹ˆ๋‹ค.")
            return image

        result = image.copy()

        for det in detections:
            x1, y1, x2, y2 = (int(v) for v in det['box'])
            label = f"{det['class_name']}: {det['score']:.2f}"

            cv2.rectangle(result, (x1, y1), (x2, y2), (0, 255, 0), 2)

            # The 20 px label strip goes above the box, or just inside it
            # when the box touches the top edge (it would be off-image).
            (text_w, _), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            top = y1 - 20 if y1 >= 20 else y1
            cv2.rectangle(result, (x1, top), (x1 + text_w, top + 20), (0, 255, 0), -1)

            cv2.putText(result, label, (x1, top + 15),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

        return result


# === Usage examples ===

def example_basic_inference():
    """Basic demo: build a random NCHW float32 input and print its shape and range.

    No model is loaded; this only illustrates the input tensor format.
    Prints a notice and returns early when onnxruntime is unavailable.
    """
    print("\n=== ๊ธฐ๋ณธ ONNX ์ถ”๋ก  ์˜ˆ์ œ ===")

    if HAS_ONNX:
        # Simulation: no real model, just a dummy input tensor.
        print("์‹œ๋ฎฌ๋ ˆ์ด์…˜ ๋ชจ๋“œ: ๋”๋ฏธ ์ž…๋ ฅ์œผ๋กœ ํ…Œ์ŠคํŠธํ•ฉ๋‹ˆ๋‹ค.")

        nchw = (1, 3, 224, 224)  # batch, channels, height, width
        sample = np.random.randn(*nchw).astype(np.float32)

        lo, hi = sample.min(), sample.max()
        print(f"์ž…๋ ฅ ํ˜•ํƒœ: {sample.shape}")
        print(f"์ž…๋ ฅ ๋ฐ์ดํ„ฐ ๋ฒ”์œ„: [{lo:.2f}, {hi:.2f}]")
    else:
        print("onnxruntime์ด ์„ค์น˜๋˜์ง€ ์•Š์•„ ์˜ˆ์ œ๋ฅผ ์‹คํ–‰ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")


def example_image_classification():
    """Image classification demo: benchmark ``resnet18.onnx`` if it exists.

    When the model file is missing, prints a PyTorch export recipe instead.
    Prints a notice and returns early when onnxruntime is unavailable.
    """
    print("\n=== ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ ์˜ˆ์ œ ===")

    if not HAS_ONNX:
        print("onnxruntime์ด ์„ค์น˜๋˜์ง€ ์•Š์•„ ์˜ˆ์ œ๋ฅผ ์‹คํ–‰ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
        return

    model_path = "resnet18.onnx"  # example path, relative to the working directory

    if os.path.exists(model_path):
        classifier = ImageClassifier(model_path)

        print("\n์„ฑ๋Šฅ ๋ฒค์น˜๋งˆํฌ:")
        stats = classifier.benchmark(num_iterations=50)
        print(f"  ํ‰๊ท : {stats['mean_ms']:.2f}ms")
        print(f"  FPS: {stats['fps']:.1f}")
        return

    # No model yet: show how to export one from PyTorch.
    export_hint = (
        f"๋ชจ๋ธ ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {model_path}",
        "PyTorch์—์„œ ๋ณ€ํ™˜ ์˜ˆ์‹œ:",
        "  import torch",
        "  model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)",
        "  dummy_input = torch.randn(1, 3, 224, 224)",
        "  torch.onnx.export(model, dummy_input, 'resnet18.onnx')",
    )
    for line in export_hint:
        print(line)


def example_object_detection():
    """Object detection demo on a random 480x640 image.

    Uses ``yolov5s.onnx`` when it exists; otherwise YOLODetector falls back to
    simulation mode and returns random boxes. Prints a notice and returns
    early when onnxruntime is unavailable, because YOLODetector raises
    ImportError without it (even in simulation mode).
    """
    print("\n=== ๊ฐ์ฒด ๊ฒ€์ถœ ์˜ˆ์ œ (์‹œ๋ฎฌ๋ ˆ์ด์…˜) ===")

    # Without this guard the ImportError from YOLODetector would abort the
    # remaining examples; skip like the other examples do.
    if not HAS_ONNX:
        print("onnxruntime์ด ์„ค์น˜๋˜์ง€ ์•Š์•„ ์˜ˆ์ œ๋ฅผ ์‹คํ–‰ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
        return

    # Simulation mode kicks in automatically if the model file is missing.
    detector = YOLODetector("yolov5s.onnx")

    # Random 480x640x3 uint8 image. randint's upper bound is exclusive,
    # so 256 covers the full 0-255 range.
    dummy_image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

    detections = detector.detect(dummy_image)

    print("\n๊ฒ€์ถœ ๊ฒฐ๊ณผ:")
    for rank, det in enumerate(detections, start=1):
        print(f"  {rank}. {det['class_name']}: {det['score']:.2f}")
        print(f"     ๋ฐ•์Šค: {det['box']}")

    # Visualisation (OpenCV only). The annotated image is neither shown nor
    # saved here; this just exercises draw_detections.
    if HAS_CV2:
        detector.draw_detections(dummy_image, detections)
        print("\n๊ฒฐ๊ณผ ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์™„๋ฃŒ")


def example_performance_comparison():
    """Throughput-by-batch-size demo (simulated).

    No model is run: a mean over the spatial axes stands in for inference,
    and samples/sec is printed for batch sizes 1, 4, 8 and 16.
    Prints a notice and returns early when onnxruntime is unavailable.
    """
    print("\n=== ์„ฑ๋Šฅ ๋น„๊ต ์˜ˆ์ œ ===")

    if not HAS_ONNX:
        print("onnxruntime์ด ์„ค์น˜๋˜์ง€ ์•Š์•„ ์˜ˆ์ œ๋ฅผ ์‹คํ–‰ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
        return

    print("๋ฐฐ์น˜ ํฌ๊ธฐ๋ณ„ ์„ฑ๋Šฅ ๋น„๊ต (์‹œ๋ฎฌ๋ ˆ์ด์…˜)")

    sample_dims = (3, 224, 224)  # C, H, W of a single sample

    for batch_size in (1, 4, 8, 16):
        batch = np.random.randn(batch_size, *sample_dims).astype(np.float32)

        t0 = time.perf_counter()
        batch.mean(axis=(2, 3))  # stand-in for a model call
        throughput = batch_size / (time.perf_counter() - t0)

        print(f"๋ฐฐ์น˜ ํฌ๊ธฐ {batch_size:2d}: {throughput:.1f} samples/sec")


# === Main entry point ===

if __name__ == "__main__":
    rule = "=" * 60

    print(rule)
    print("ONNX Runtime Edge AI ์ถ”๋ก  ์˜ˆ์ œ")
    print(rule)

    # Report the ONNX Runtime install status before running the examples.
    if not HAS_ONNX:
        print("\n๊ฒฝ๊ณ : ONNX Runtime์ด ์„ค์น˜๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
        print("์„ค์น˜: pip install onnxruntime")
    else:
        print(f"\nONNX Runtime ๋ฒ„์ „: {ort.__version__}")
        print(f"์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ํ”„๋กœ๋ฐ”์ด๋”: {ort.get_available_providers()}")

    # Run every example in order.
    for run_example in (
        example_basic_inference,
        example_image_classification,
        example_object_detection,
        example_performance_comparison,
    ):
        run_example()

    print("\n" + rule)
    print("๋ชจ๋“  ์˜ˆ์ œ ์™„๋ฃŒ")
    print(rule)