# 22_object_tracking.py

"""
Real-time Object Tracking with OpenCV

This script demonstrates multiple object tracking algorithms:
- KCF (Kernelized Correlation Filters)
- CSRT (Discriminative Correlation Filter with Channel and Spatial Reliability)
- MOSSE (Minimum Output Sum of Squared Error)
- Centroid-based tracking (from scratch)
- Kalman filter for motion prediction

Requirements:
    pip install opencv-contrib-python numpy

Author: Claude
Date: 2026-02-15
"""
 17
import os
import sys
import tempfile
from typing import List, Tuple, Dict, Optional

import cv2
import numpy as np
 22
 23
 24class CentroidTracker:
 25    """
 26    Simple centroid-based object tracker.
 27
 28    Tracks objects by computing centroids and matching them across frames
 29    using Euclidean distance.
 30    """
 31
 32    def __init__(self, max_disappeared: int = 50):
 33        """
 34        Args:
 35            max_disappeared: Maximum frames an object can disappear before removal
 36        """
 37        self.next_object_id = 0
 38        self.objects: Dict[int, Tuple[int, int]] = {}
 39        self.disappeared = {}
 40        self.max_disappeared = max_disappeared
 41
 42    def register(self, centroid: Tuple[int, int]) -> int:
 43        """Register a new object with its centroid."""
 44        object_id = self.next_object_id
 45        self.objects[object_id] = centroid
 46        self.disappeared[object_id] = 0
 47        self.next_object_id += 1
 48        return object_id
 49
 50    def deregister(self, object_id: int):
 51        """Remove an object from tracking."""
 52        del self.objects[object_id]
 53        del self.disappeared[object_id]
 54
 55    def update(self, rects: List[Tuple[int, int, int, int]]) -> Dict[int, Tuple[int, int]]:
 56        """
 57        Update tracker with new detections.
 58
 59        Args:
 60            rects: List of bounding boxes (x, y, w, h)
 61
 62        Returns:
 63            Dictionary mapping object IDs to centroids
 64        """
 65        # If no detections, mark all objects as disappeared
 66        if len(rects) == 0:
 67            for object_id in list(self.disappeared.keys()):
 68                self.disappeared[object_id] += 1
 69                if self.disappeared[object_id] > self.max_disappeared:
 70                    self.deregister(object_id)
 71            return self.objects
 72
 73        # Compute centroids of new detections
 74        input_centroids = []
 75        for (x, y, w, h) in rects:
 76            cx = int(x + w / 2)
 77            cy = int(y + h / 2)
 78            input_centroids.append((cx, cy))
 79
 80        # If no existing objects, register all new centroids
 81        if len(self.objects) == 0:
 82            for centroid in input_centroids:
 83                self.register(centroid)
 84        else:
 85            # Match existing objects to new centroids
 86            object_ids = list(self.objects.keys())
 87            object_centroids = list(self.objects.values())
 88
 89            # Compute distance matrix
 90            D = np.zeros((len(object_centroids), len(input_centroids)))
 91            for i, obj_centroid in enumerate(object_centroids):
 92                for j, input_centroid in enumerate(input_centroids):
 93                    D[i, j] = np.linalg.norm(
 94                        np.array(obj_centroid) - np.array(input_centroid)
 95                    )
 96
 97            # Find minimum distance assignments
 98            rows = D.min(axis=1).argsort()
 99            cols = D.argmin(axis=1)[rows]
100
101            used_rows = set()
102            used_cols = set()
103
104            # Update matched objects
105            for (row, col) in zip(rows, cols):
106                if row in used_rows or col in used_cols:
107                    continue
108
109                object_id = object_ids[row]
110                self.objects[object_id] = input_centroids[col]
111                self.disappeared[object_id] = 0
112
113                used_rows.add(row)
114                used_cols.add(col)
115
116            # Handle unmatched objects
117            unused_rows = set(range(D.shape[0])) - used_rows
118            for row in unused_rows:
119                object_id = object_ids[row]
120                self.disappeared[object_id] += 1
121                if self.disappeared[object_id] > self.max_disappeared:
122                    self.deregister(object_id)
123
124            # Register new objects
125            unused_cols = set(range(D.shape[1])) - used_cols
126            for col in unused_cols:
127                self.register(input_centroids[col])
128
129        return self.objects
130
131
class KalmanTracker:
    """
    Kalman filter-based tracker for motion prediction.

    Wraps a ``cv2.KalmanFilter`` with a 4-element state ``[x, y, vx, vy]``
    and a 2-element measurement ``[x, y]``. The transition matrix adds the
    velocity to the position once per step (constant-velocity model with a
    time step of one frame).
    """

    def __init__(self):
        """Initialize Kalman filter for 2D position tracking."""
        # State: [x, y, vx, vy] (position and velocity)
        self.kalman = cv2.KalmanFilter(4, 2)

        # Only the position components of the state are observed.
        self.kalman.measurementMatrix = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 0]
        ], np.float32)

        # x += vx, y += vy on every prediction step.
        self.kalman.transitionMatrix = np.array([
            [1, 0, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ], np.float32)

        self.kalman.processNoiseCov = np.eye(4, dtype=np.float32) * 0.03
        self.kalman.measurementNoiseCov = np.eye(2, dtype=np.float32) * 0.1

        # False until the first measurement has been used to seed the state.
        self._seeded = False

    def predict(self) -> Tuple[int, int]:
        """Advance the filter one step and return the predicted (x, y)."""
        prediction = self.kalman.predict()
        # predict() returns a column vector; index the scalars explicitly
        # instead of relying on int() of a one-element array.
        return int(prediction[0, 0]), int(prediction[1, 0])

    def update(self, x: int, y: int) -> Tuple[int, int]:
        """
        Update with measurement and return corrected position.

        The first measurement seeds the filter state directly (position set
        to the measurement, velocity zero) so the estimate does not have to
        converge from the filter's initial all-zero state. Later
        measurements go through the normal Kalman correction step.

        Args:
            x, y: Measured position

        Returns:
            The filter's corrected (x, y) estimate, truncated to integers
        """
        if not self._seeded:
            seed = np.array([[x], [y], [0], [0]], dtype=np.float32)
            # Seed both the a-priori and a-posteriori state: correct() builds
            # on statePre, predict() builds on statePost.
            self.kalman.statePre = seed.copy()
            self.kalman.statePost = seed.copy()
            self._seeded = True
            return int(x), int(y)

        measurement = np.array([[x], [y]], dtype=np.float32)
        corrected = self.kalman.correct(measurement)
        return int(corrected[0, 0]), int(corrected[1, 0])
164
165
def compute_iou(box1: Tuple[int, int, int, int],
                box2: Tuple[int, int, int, int]) -> float:
    """
    Return the Intersection over Union (IoU) of two bounding boxes.

    Args:
        box1, box2: Bounding boxes in format (x, y, w, h)

    Returns:
        Overlap area divided by the area of the union; 0.0 when the union
        has no positive area.
    """
    left_a, top_a, width_a, height_a = box1
    left_b, top_b, width_b, height_b = box2

    # Extent of the overlap along each axis; negative means no overlap.
    overlap_w = min(left_a + width_a, left_b + width_b) - max(left_a, left_b)
    overlap_h = min(top_a + height_a, top_b + height_b) - max(top_a, top_b)
    intersection = max(0, overlap_w) * max(0, overlap_h)

    # Union = both areas minus the part counted twice.
    union = width_a * height_a + width_b * height_b - intersection

    if union > 0:
        return intersection / union
    return 0.0
194
195
def generate_synthetic_video(output_path: str,
                            num_frames: int = 300,
                            width: int = 640,
                            height: int = 480,
                            fps: float = 20.0) -> str:
    """
    Generate a synthetic video with moving objects for testing.

    Three filled squares move across a black background at constant
    velocity and reverse direction when they reach a frame edge. The frame
    index is drawn in the top-left corner of every frame.

    Args:
        output_path: Path to save the video
        num_frames: Number of frames to generate
        width, height: Video dimensions
        fps: Frame rate written into the video file

    Returns:
        Path to the generated video

    Raises:
        OSError: If the video writer cannot be opened for ``output_path``
            (for example because the XVID codec is unavailable).
    """
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    # Fail loudly instead of reporting success for a file that was never
    # written.
    if not out.isOpened():
        out.release()
        raise OSError(f"Cannot open video writer for {output_path}")

    # Define moving objects (x, y, vx, vy, size, color); colors are BGR.
    objects = [
        [50, 50, 2, 1, 30, (0, 255, 0)],      # Green square
        [300, 100, -1, 2, 25, (255, 0, 0)],    # Blue square
        [100, 300, 1.5, -1, 35, (0, 0, 255)],  # Red square
    ]

    try:
        for frame_idx in range(num_frames):
            # Create blank frame
            frame = np.zeros((height, width, 3), dtype=np.uint8)

            # Update and draw objects
            for obj in objects:
                x, y, vx, vy, size, color = obj

                # Draw object at its current position
                cv2.rectangle(frame,
                              (int(x), int(y)),
                              (int(x + size), int(y + size)),
                              color, -1)

                # Update position
                obj[0] += vx
                obj[1] += vy

                # Bounce off walls: reverse the velocity component whose
                # axis reached an edge.
                if obj[0] <= 0 or obj[0] >= width - size:
                    obj[2] *= -1
                if obj[1] <= 0 or obj[1] >= height - size:
                    obj[3] *= -1

            # Add frame number
            cv2.putText(frame, f'Frame: {frame_idx}', (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

            out.write(frame)
    finally:
        # Always finalize the file, even if drawing or writing fails.
        out.release()

    print(f"Generated synthetic video: {output_path}")
    return output_path
254
255
def demo_opencv_trackers(video_path: Optional[str] = None):
    """
    Demonstrate OpenCV's built-in tracking algorithms.

    Runs KCF and CSRT (plus MOSSE when the legacy module provides it) on the
    same user-selected region, draws each tracker's box, and feeds the CSRT
    box centre into a Kalman filter whose prediction is drawn as a dot.
    Press 'q' to quit and 'p' to pause until the next key press.

    Args:
        video_path: Path to input video (None for synthetic video)
    """
    # Generate synthetic video if no input provided
    if video_path is None:
        video_path = os.path.join(tempfile.gettempdir(), 'tracking_demo.avi')
        generate_synthetic_video(video_path)

    # Available tracker types
    tracker_types = {
        'KCF': cv2.TrackerKCF_create,
        'CSRT': cv2.TrackerCSRT_create,
    }

    # Try to add MOSSE tracker if available (in legacy module)
    try:
        tracker_types['MOSSE'] = cv2.legacy.TrackerMOSSE_create
    except AttributeError:
        print("MOSSE tracker not available (requires opencv-contrib-python)")

    # BGR drawing colour per tracker; any other tracker uses the fallback.
    box_colors = {'CSRT': (0, 255, 0), 'KCF': (255, 0, 0)}
    fallback_color = (0, 165, 255)

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Cannot open video {video_path}")
            return

        # Read first frame
        ret, frame = cap.read()
        if not ret:
            print("Error: Cannot read first frame")
            return

        # Select ROI for tracking
        print("Select ROI to track, then press SPACE or ENTER")
        bbox = cv2.selectROI("Select ROI", frame, False)
        cv2.destroyWindow("Select ROI")

        # A zero-sized selection means the user cancelled.
        if bbox[2] == 0 or bbox[3] == 0:
            print("No ROI selected, using default")
            bbox = (100, 100, 60, 60)

        # Initialize trackers, all on the same first frame and region
        trackers = {}
        for name, create_fn in tracker_types.items():
            tracker = create_fn()
            tracker.init(frame, bbox)
            trackers[name] = tracker

        # Initialize Kalman filter with the centre of the selected region
        kalman = KalmanTracker()
        cx, cy = int(bbox[0] + bbox[2] / 2), int(bbox[1] + bbox[3] / 2)
        kalman.update(cx, cy)

        frame_count = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            display = frame.copy()

            # Track with each algorithm; y_offset stacks the "Lost" labels.
            y_offset = 30
            for name, tracker in trackers.items():
                success, box = tracker.update(frame)

                if success:
                    x, y, w, h = [int(v) for v in box]
                    color = box_colors.get(name, fallback_color)
                    cv2.rectangle(display, (x, y), (x + w, y + h), color, 2)
                    cv2.putText(display, f'{name}', (x, y - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

                    # Update Kalman filter with CSRT results (most accurate)
                    if name == 'CSRT':
                        cx, cy = int(x + w / 2), int(y + h / 2)
                        kalman.update(cx, cy)
                else:
                    cv2.putText(display, f'{name}: Lost', (10, y_offset),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

                y_offset += 25

            # Show Kalman prediction
            pred_x, pred_y = kalman.predict()
            cv2.circle(display, (pred_x, pred_y), 5, (255, 255, 0), -1)
            cv2.putText(display, 'Kalman', (pred_x + 10, pred_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 2)

            cv2.putText(display, f'Frame: {frame_count}',
                        (10, frame.shape[0] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            cv2.imshow('Object Tracking', display)

            key = cv2.waitKey(30) & 0xFF
            if key == ord('q'):
                break
            elif key == ord('p'):
                cv2.waitKey(0)  # Pause until any key is pressed
    finally:
        # Runs on every exit path, including the early returns above.
        cap.release()
        cv2.destroyAllWindows()
363
364
def demo_centroid_tracking():
    """
    Demonstrate centroid-based tracking with background subtraction.

    A synthetic video is generated, moving regions are extracted with a MOG2
    background subtractor, and the bounding boxes of sufficiently large
    regions are fed to a CentroidTracker. Press 'q' to quit and 'p' to pause
    until the next key press.
    """
    # Generate synthetic video
    video_path = os.path.join(tempfile.gettempdir(),
                              'centroid_tracking_demo.avi')
    generate_synthetic_video(video_path, num_frames=200)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Cannot open video {video_path}")
        cap.release()
        return

    tracker = CentroidTracker(max_disappeared=30)

    # Background subtractor for detecting moving objects
    bg_subtractor = cv2.createBackgroundSubtractorMOG2(
        history=100, varThreshold=40, detectShadows=True
    )

    # The structuring element never changes, so build it once.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

    colors = [
        (0, 255, 0), (255, 0, 0), (0, 0, 255),
        (255, 255, 0), (255, 0, 255), (0, 255, 255)
    ]

    frame_count = 0

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1

            # Apply background subtraction
            fg_mask = bg_subtractor.apply(frame)

            # With detectShadows=True the mask marks shadows with a grey
            # value; keep only the bright foreground pixels so shadows are
            # not picked up as objects by findContours.
            _, fg_mask = cv2.threshold(fg_mask, 200, 255, cv2.THRESH_BINARY)

            # Clean up the mask: close small holes, then remove speckles
            fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_CLOSE, kernel)
            fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_OPEN, kernel)

            # Find contours
            contours, _ = cv2.findContours(fg_mask, cv2.RETR_EXTERNAL,
                                           cv2.CHAIN_APPROX_SIMPLE)

            # Extract bounding boxes, skipping small regions
            rects = [
                cv2.boundingRect(contour)
                for contour in contours
                if cv2.contourArea(contour) >= 500
            ]

            # Update tracker
            objects = tracker.update(rects)

            # Draw tracked objects; colours cycle with the object ID
            for object_id, centroid in objects.items():
                color = colors[object_id % len(colors)]
                cv2.circle(frame, centroid, 5, color, -1)
                cv2.putText(frame, f'ID: {object_id}',
                            (centroid[0] - 10, centroid[1] - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

            # Draw bounding boxes
            for (x, y, w, h) in rects:
                cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 255), 2)

            cv2.putText(frame,
                        f'Frame: {frame_count} | Objects: {len(objects)}',
                        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                        (255, 255, 255), 2)

            cv2.imshow('Centroid Tracking', frame)
            cv2.imshow('Foreground Mask', fg_mask)

            key = cv2.waitKey(30) & 0xFF
            if key == ord('q'):
                break
            elif key == ord('p'):
                cv2.waitKey(0)  # Pause until any key is pressed
    finally:
        cap.release()
        cv2.destroyAllWindows()
442
443
def demo_multi_tracker():
    """
    Demonstrate tracking multiple objects simultaneously.

    A synthetic video is generated and three fixed regions are tracked
    together through a legacy MultiTracker holding one CSRT tracker per
    region. Press 'q' to quit.
    """
    # Generate synthetic video
    video_path = os.path.join(tempfile.gettempdir(),
                              'multi_tracking_demo.avi')
    generate_synthetic_video(video_path, num_frames=200)

    cap = cv2.VideoCapture(video_path)
    try:
        ret, frame = cap.read()
        if not ret:
            print("Error: Cannot read video")
            return

        # Create MultiTracker
        multi_tracker = cv2.legacy.MultiTracker_create()

        # Define multiple ROIs (x, y, w, h)
        bboxes = [
            (50, 50, 40, 40),
            (300, 100, 35, 35),
            (100, 300, 45, 45)
        ]

        colors = [(0, 255, 0), (255, 0, 0), (0, 0, 255)]

        # Initialize trackers. The legacy MultiTracker works with trackers
        # from the same legacy module, so the CSRT tracker is created there
        # rather than through the top-level cv2.TrackerCSRT_create.
        for bbox in bboxes:
            tracker = cv2.legacy.TrackerCSRT_create()
            multi_tracker.add(tracker, frame, bbox)

        frame_count = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1

            # Update all trackers
            success, boxes = multi_tracker.update(frame)

            # Draw tracked objects
            for i, box in enumerate(boxes):
                x, y, w, h = [int(v) for v in box]
                color = colors[i % len(colors)]
                cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
                cv2.putText(frame, f'Object {i+1}', (x, y - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

            # Surface the aggregate status instead of silently ignoring it.
            if not success:
                cv2.putText(frame, 'Tracking failure', (10, 55),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

            cv2.putText(frame, f'Frame: {frame_count}', (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

            cv2.imshow('Multi-Object Tracking', frame)

            key = cv2.waitKey(30) & 0xFF
            if key == ord('q'):
                break
    finally:
        # Runs on every exit path, including the early return above.
        cap.release()
        cv2.destroyAllWindows()
504
505
if __name__ == "__main__":
    # Script entry point: run the three demos one after another.
    print("Object Tracking Demonstrations")
    print("=" * 50)
    print("\nAvailable demos:")
    print("1. OpenCV Trackers (KCF, CSRT, MOSSE) + Kalman Filter")
    print("2. Centroid-Based Tracking")
    print("3. Multi-Object Tracking")
    print("\nPress Ctrl+C to exit anytime")

    # (heading, entry point) for each demo, in execution order.
    demos = [
        ("[Demo 1] OpenCV Trackers with Kalman Filter", demo_opencv_trackers),
        ("[Demo 2] Centroid-Based Tracking", demo_centroid_tracking),
        ("[Demo 3] Multi-Object Tracking", demo_multi_tracker),
    ]

    try:
        for heading, run_demo in demos:
            print(f"\n{heading}")
            print("-" * 50)
            run_demo()
    except KeyboardInterrupt:
        print("\n\nDemo interrupted by user")
    except Exception as e:
        # Top-level boundary: report the failure with its traceback.
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()
    else:
        # Only claim completion when every demo actually ran to the end.
        print("\nAll demos completed!")
    finally:
        cv2.destroyAllWindows()