1#!/usr/bin/env python3
2"""
3TensorFlow Lite Inference Example
TFLite 추론 예제 (TFLite inference example)
5
6Performs image classification using a TFLite model.
7
8Install:
9 pip install tflite-runtime numpy pillow
10
11Usage:
12 python3 tflite_inference.py --model model.tflite --image image.jpg
13 python3 tflite_inference.py --model model.tflite --image image.jpg --labels labels.txt
14"""
15
16import numpy as np
17from PIL import Image
18import time
19import argparse
20import os
21
# Resolve a TFLite Interpreter implementation: prefer the lightweight
# tflite-runtime package, fall back to a full TensorFlow installation.
# NOTE(review): tensorflow.lite.python.interpreter is a private module path;
# the public alias is tf.lite.Interpreter — confirm against the pinned TF version.
try:
    from tflite_runtime.interpreter import Interpreter
except ImportError:
    try:
        from tensorflow.lite.python.interpreter import Interpreter
    except ImportError:
        print("Error: tflite-runtime or tensorflow not found")
        print("Install with: pip install tflite-runtime")
        # BUG FIX: the bare exit() builtin is injected by the `site` module
        # and may not exist (python -S, frozen/embedded interpreters).
        # SystemExit is a true builtin and behaves identically here.
        raise SystemExit(1)
32
class TFLiteClassifier:
    """TensorFlow Lite image classifier.

    Wraps a TFLite interpreter and provides image preprocessing, top-k
    classification, and a latency benchmark.  Handles both float models
    and quantized (uint8/int8) models — the original implementation fed
    float32 tensors unconditionally, which the interpreter rejects for
    quantized inputs.
    """

    def __init__(self, model_path: str, labels_path: str = None):
        """Load the model, allocate tensors, and optionally read labels.

        Args:
            model_path: Path to a .tflite model file.
            labels_path: Optional path to a newline-separated labels file;
                silently skipped when missing.

        Raises:
            FileNotFoundError: If model_path does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model not found: {model_path}")

        # Load model
        print(f"Loading model: {model_path}")
        self.interpreter = Interpreter(model_path=model_path)
        self.interpreter.allocate_tensors()

        # Get input/output details
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()

        # Input shape — assumes NHWC layout (shape[1]=height, shape[2]=width),
        # the TFLite convention for vision models.
        self.input_shape = self.input_details[0]['shape']
        self.input_height = self.input_shape[1]
        self.input_width = self.input_shape[2]
        self.input_dtype = self.input_details[0]['dtype']

        print(f"Input shape: {self.input_shape}")
        print(f"Input dtype: {self.input_dtype}")

        # Load labels (one label per line; index in file == class id)
        self.labels = []
        if labels_path and os.path.exists(labels_path):
            with open(labels_path, 'r') as f:
                self.labels = [line.strip() for line in f]
            print(f"Loaded {len(self.labels)} labels")

    def preprocess(self, image_path: str) -> np.ndarray:
        """Load an image and convert it to a model-ready input tensor.

        Float models receive MobileNet-style [-1, 1] normalization.
        Quantized models get the same normalized value mapped into the
        integer domain via the tensor's quantization parameters
        (real = scale * (q - zero_point)), falling back to a raw-pixel
        cast when no quantization parameters are present.

        Returns:
            A batched array of shape (1, H, W, 3) with the model's dtype.
        """
        image = Image.open(image_path).convert('RGB')
        image = image.resize((self.input_width, self.input_height))

        if np.issubdtype(np.dtype(self.input_dtype), np.floating):
            input_data = np.asarray(image, dtype=np.float32)
            # Normalize (MobileNet style: -1 to 1)
            input_data = (input_data - 127.5) / 127.5
        else:
            # BUG FIX: previously always produced float32, which fails
            # set_tensor() for uint8/int8 models. Quantize the normalized
            # value into the integer domain instead.
            scale, zero_point = self.input_details[0]['quantization']
            pixels = np.asarray(image, dtype=np.float32)
            if scale:
                pixels = ((pixels - 127.5) / 127.5) / scale + zero_point
            info = np.iinfo(self.input_dtype)
            input_data = np.clip(np.round(pixels), info.min, info.max)
            input_data = input_data.astype(self.input_dtype)

        # Add batch dimension
        return np.expand_dims(input_data, axis=0)

    def classify(self, image_path: str, top_k: int = 5) -> dict:
        """Classify an image and return the top-k predictions.

        Args:
            image_path: Path to the image file.
            top_k: Number of highest-scoring classes to return.

        Returns:
            Dict with keys "image", "predictions", "inference_time_ms",
            and "model_input_size".

        Raises:
            FileNotFoundError: If image_path does not exist.
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found: {image_path}")

        # Preprocess
        input_data = self.preprocess(image_path)

        # Inference (timed around set/invoke/get only)
        start_time = time.perf_counter()
        self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
        self.interpreter.invoke()
        output = self.interpreter.get_tensor(self.output_details[0]['index'])[0]
        inference_time = (time.perf_counter() - start_time) * 1000

        # Dequantize integer outputs so scores behave like probabilities;
        # otherwise "confidence" percentages are raw quantized codes.
        out_detail = self.output_details[0]
        if not np.issubdtype(np.dtype(out_detail['dtype']), np.floating):
            scale, zero_point = out_detail['quantization']
            if scale:
                output = scale * (output.astype(np.float32) - zero_point)

        # Get top-k predictions (clamped to the number of classes)
        k = min(top_k, output.size)
        top_indices = output.argsort()[-k:][::-1]

        predictions = []
        for idx in top_indices:
            label = self.labels[idx] if idx < len(self.labels) else f"class_{idx}"
            score = float(output[idx])
            predictions.append({
                "class_id": int(idx),
                "label": label,
                "score": score,
                "confidence": f"{score * 100:.1f}%"
            })

        return {
            "image": image_path,
            "predictions": predictions,
            "inference_time_ms": round(inference_time, 2),
            "model_input_size": f"{self.input_width}x{self.input_height}"
        }

    def benchmark(self, num_runs: int = 100) -> dict:
        """Benchmark inference latency over num_runs invocations.

        Returns:
            Dict with avg/std/min/max latency in ms, run count, and FPS.
        """
        print(f"\nBenchmarking ({num_runs} runs)...")

        # BUG FIX: the dummy input must match the model's input dtype;
        # a float32 tensor is rejected by quantized models.
        if np.issubdtype(np.dtype(self.input_dtype), np.floating):
            dummy_input = np.random.rand(*self.input_shape).astype(self.input_dtype)
        else:
            info = np.iinfo(self.input_dtype)
            dummy_input = np.random.randint(
                info.min, info.max + 1,
                size=tuple(self.input_shape), dtype=self.input_dtype)

        # Warmup so first-run allocation/JIT costs don't skew the stats
        for _ in range(10):
            self.interpreter.set_tensor(self.input_details[0]['index'], dummy_input)
            self.interpreter.invoke()

        # Timed runs
        times = []
        for _ in range(num_runs):
            start = time.perf_counter()
            self.interpreter.set_tensor(self.input_details[0]['index'], dummy_input)
            self.interpreter.invoke()
            _ = self.interpreter.get_tensor(self.output_details[0]['index'])
            times.append((time.perf_counter() - start) * 1000)

        avg_time = float(np.mean(times))
        std_time = float(np.std(times))
        fps = 1000 / avg_time

        return {
            "runs": num_runs,
            "avg_time_ms": round(avg_time, 2),
            "std_time_ms": round(std_time, 2),
            "fps": round(fps, 1),
            "min_time_ms": round(min(times), 2),
            "max_time_ms": round(max(times), 2)
        }
154
155def create_dummy_model():
156 """Create a dummy TFLite model for testing"""
157 try:
158 import tensorflow as tf
159
160 # Simple model
161 model = tf.keras.Sequential([
162 tf.keras.layers.Input(shape=(224, 224, 3)),
163 tf.keras.layers.Conv2D(16, 3, activation='relu'),
164 tf.keras.layers.GlobalAveragePooling2D(),
165 tf.keras.layers.Dense(10, activation='softmax')
166 ])
167
168 # Convert to TFLite
169 converter = tf.lite.TFLiteConverter.from_keras_model(model)
170 tflite_model = converter.convert()
171
172 # Save
173 with open('dummy_model.tflite', 'wb') as f:
174 f.write(tflite_model)
175
176 print("Created dummy_model.tflite for testing")
177 return 'dummy_model.tflite'
178
179 except ImportError:
180 print("TensorFlow not available. Cannot create dummy model.")
181 return None
182
def main():
    """Command-line entry point: classify an image or run a benchmark."""
    arg_parser = argparse.ArgumentParser(description="TFLite Image Classifier")
    arg_parser.add_argument("--model", required=True, help="Path to TFLite model")
    arg_parser.add_argument("--image", help="Path to image file")
    arg_parser.add_argument("--labels", help="Path to labels file")
    arg_parser.add_argument("--top-k", type=int, default=5, help="Number of top predictions")
    arg_parser.add_argument("--benchmark", action="store_true", help="Run benchmark")
    opts = arg_parser.parse_args()

    print("=== TFLite Inference ===\n")

    try:
        classifier = TFLiteClassifier(opts.model, opts.labels)

        if opts.benchmark:
            # Latency benchmark on synthetic input.
            stats = classifier.benchmark()
            print("\nBenchmark Results:")
            print(f" Average time: {stats['avg_time_ms']:.2f} ms (+/- {stats['std_time_ms']:.2f})")
            print(f" FPS: {stats['fps']:.1f}")
            print(f" Min/Max: {stats['min_time_ms']:.2f} / {stats['max_time_ms']:.2f} ms")

        elif opts.image:
            # Single-image classification.
            report = classifier.classify(opts.image, opts.top_k)

            print(f"\nImage: {report['image']}")
            print(f"Inference time: {report['inference_time_ms']} ms")
            print(f"\nTop-{opts.top_k} Predictions:")

            for rank, entry in enumerate(report['predictions'], 1):
                print(f" {rank}. {entry['label']}: {entry['confidence']}")

        else:
            # No action requested: show the model's basic properties.
            print("Please specify --image or --benchmark")
            print("\nModel info:")
            print(f" Input shape: {classifier.input_shape}")
            print(f" Input dtype: {classifier.input_dtype}")

    except FileNotFoundError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"Error: {e}")
227
# Run the CLI only when executed as a script (not on import).
if __name__ == "__main__":
    main()