tflite_inference.py

Download
python 230 lines 7.7 KB
  1#!/usr/bin/env python3
  2"""
  3TensorFlow Lite Inference Example
  4TFLite 추론 예제
  5
  6Performs image classification using a TFLite model.
  7
  8Install:
  9    pip install tflite-runtime numpy pillow
 10
 11Usage:
 12    python3 tflite_inference.py --model model.tflite --image image.jpg
 13    python3 tflite_inference.py --model model.tflite --image image.jpg --labels labels.txt
 14"""
 15
 16import numpy as np
 17from PIL import Image
 18import time
 19import argparse
 20import os
 21
 22# Try to import TFLite runtime
 23try:
 24    from tflite_runtime.interpreter import Interpreter
 25except ImportError:
 26    try:
 27        from tensorflow.lite.python.interpreter import Interpreter
 28    except ImportError:
 29        print("Error: tflite-runtime or tensorflow not found")
 30        print("Install with: pip install tflite-runtime")
 31        exit(1)
 32
 33class TFLiteClassifier:
 34    """TensorFlow Lite Image Classifier"""
 35
 36    def __init__(self, model_path: str, labels_path: str = None):
 37        """Initialize classifier with model and optional labels"""
 38        if not os.path.exists(model_path):
 39            raise FileNotFoundError(f"Model not found: {model_path}")
 40
 41        # Load model
 42        print(f"Loading model: {model_path}")
 43        self.interpreter = Interpreter(model_path=model_path)
 44        self.interpreter.allocate_tensors()
 45
 46        # Get input/output details
 47        self.input_details = self.interpreter.get_input_details()
 48        self.output_details = self.interpreter.get_output_details()
 49
 50        # Input shape
 51        self.input_shape = self.input_details[0]['shape']
 52        self.input_height = self.input_shape[1]
 53        self.input_width = self.input_shape[2]
 54        self.input_dtype = self.input_details[0]['dtype']
 55
 56        print(f"Input shape: {self.input_shape}")
 57        print(f"Input dtype: {self.input_dtype}")
 58
 59        # Load labels
 60        self.labels = []
 61        if labels_path and os.path.exists(labels_path):
 62            with open(labels_path, 'r') as f:
 63                self.labels = [line.strip() for line in f.readlines()]
 64            print(f"Loaded {len(self.labels)} labels")
 65
 66    def preprocess(self, image_path: str) -> np.ndarray:
 67        """Preprocess image for inference"""
 68        # Load and resize image
 69        image = Image.open(image_path).convert('RGB')
 70        image = image.resize((self.input_width, self.input_height))
 71
 72        # Convert to numpy array
 73        input_data = np.array(image, dtype=np.float32)
 74
 75        # Normalize (MobileNet style: -1 to 1)
 76        input_data = (input_data - 127.5) / 127.5
 77
 78        # Add batch dimension
 79        input_data = np.expand_dims(input_data, axis=0)
 80
 81        return input_data
 82
 83    def classify(self, image_path: str, top_k: int = 5) -> dict:
 84        """Classify image and return top-k predictions"""
 85        if not os.path.exists(image_path):
 86            raise FileNotFoundError(f"Image not found: {image_path}")
 87
 88        # Preprocess
 89        input_data = self.preprocess(image_path)
 90
 91        # Inference
 92        start_time = time.perf_counter()
 93
 94        self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
 95        self.interpreter.invoke()
 96        output = self.interpreter.get_tensor(self.output_details[0]['index'])[0]
 97
 98        inference_time = (time.perf_counter() - start_time) * 1000
 99
100        # Get top-k predictions
101        top_indices = output.argsort()[-top_k:][::-1]
102
103        predictions = []
104        for idx in top_indices:
105            label = self.labels[idx] if idx < len(self.labels) else f"class_{idx}"
106            score = float(output[idx])
107            predictions.append({
108                "class_id": int(idx),
109                "label": label,
110                "score": score,
111                "confidence": f"{score * 100:.1f}%"
112            })
113
114        return {
115            "image": image_path,
116            "predictions": predictions,
117            "inference_time_ms": round(inference_time, 2),
118            "model_input_size": f"{self.input_width}x{self.input_height}"
119        }
120
121    def benchmark(self, num_runs: int = 100) -> dict:
122        """Benchmark inference speed"""
123        print(f"\nBenchmarking ({num_runs} runs)...")
124
125        # Create dummy input
126        dummy_input = np.random.rand(*self.input_shape).astype(np.float32)
127
128        # Warmup
129        for _ in range(10):
130            self.interpreter.set_tensor(self.input_details[0]['index'], dummy_input)
131            self.interpreter.invoke()
132
133        # Benchmark
134        times = []
135        for _ in range(num_runs):
136            start = time.perf_counter()
137            self.interpreter.set_tensor(self.input_details[0]['index'], dummy_input)
138            self.interpreter.invoke()
139            _ = self.interpreter.get_tensor(self.output_details[0]['index'])
140            times.append((time.perf_counter() - start) * 1000)
141
142        avg_time = np.mean(times)
143        std_time = np.std(times)
144        fps = 1000 / avg_time
145
146        return {
147            "runs": num_runs,
148            "avg_time_ms": round(avg_time, 2),
149            "std_time_ms": round(std_time, 2),
150            "fps": round(fps, 1),
151            "min_time_ms": round(min(times), 2),
152            "max_time_ms": round(max(times), 2)
153        }
154
def create_dummy_model(output_path: str = 'dummy_model.tflite'):
    """Create a small TFLite model for testing.

    Builds a tiny Keras network (224x224x3 input, 10 softmax outputs),
    converts it to TFLite and writes it to ``output_path``.

    Args:
        output_path: Where to write the converted model. Defaults to
            ``dummy_model.tflite`` in the current working directory.

    Returns:
        ``output_path`` on success, or ``None`` when TensorFlow is not
        installed.
    """
    # Only the import is guarded, so errors raised while building or
    # converting the model are not mistaken for "TensorFlow missing".
    try:
        import tensorflow as tf
    except ImportError:
        print("TensorFlow not available. Cannot create dummy model.")
        return None

    # Simple model
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(224, 224, 3)),
        tf.keras.layers.Conv2D(16, 3, activation='relu'),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    # Convert to TFLite
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()

    # Save
    with open(output_path, 'wb') as f:
        f.write(tflite_model)

    print(f"Created {output_path} for testing")
    return output_path
182
def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` means use ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 when loading the model,
        classifying or benchmarking failed. Invalid command-line arguments
        make argparse exit with status 2.
    """
    import sys  # local import: only needed for writing errors to stderr

    parser = argparse.ArgumentParser(description="TFLite Image Classifier")
    parser.add_argument("--model", required=True, help="Path to TFLite model")
    parser.add_argument("--image", help="Path to image file")
    parser.add_argument("--labels", help="Path to labels file")
    parser.add_argument("--top-k", type=int, default=5, help="Number of top predictions")
    parser.add_argument("--benchmark", action="store_true", help="Run benchmark")
    args = parser.parse_args(argv)

    print("=== TFLite Inference ===\n")

    try:
        classifier = TFLiteClassifier(args.model, args.labels)

        if args.benchmark:
            # Run benchmark (takes precedence over --image)
            results = classifier.benchmark()
            print("\nBenchmark Results:")
            print(f"  Average time: {results['avg_time_ms']:.2f} ms (+/- {results['std_time_ms']:.2f})")
            print(f"  FPS: {results['fps']:.1f}")
            print(f"  Min/Max: {results['min_time_ms']:.2f} / {results['max_time_ms']:.2f} ms")

        elif args.image:
            # Classify image
            results = classifier.classify(args.image, args.top_k)

            print(f"\nImage: {results['image']}")
            print(f"Inference time: {results['inference_time_ms']} ms")
            print(f"\nTop-{args.top_k} Predictions:")

            for i, pred in enumerate(results['predictions'], 1):
                print(f"  {i}. {pred['label']}: {pred['confidence']}")

        else:
            print("Please specify --image or --benchmark")
            print("\nModel info:")
            print(f"  Input shape: {classifier.input_shape}")
            print(f"  Input dtype: {classifier.input_dtype}")

    except Exception as e:
        # Top-level boundary: report any failure (missing files included) on
        # stderr and signal it through the exit status instead of exiting 0.
        print(f"Error: {e}", file=sys.stderr)
        return 1

    return 0


if __name__ == "__main__":
    raise SystemExit(main())