# 23_panorama_stitching.py
"""
Panorama Image Stitching

This script demonstrates panorama creation from multiple images using:
- Feature detection (SIFT, ORB)
- Feature matching (BFMatcher, FLANN)
- Homography estimation with RANSAC
- Image warping and blending

Requirements:
    pip install opencv-contrib-python numpy

Author: Claude
Date: 2026-02-15
"""
 16
 17import cv2
 18import numpy as np
 19from typing import List, Tuple, Optional
 20import sys
 21
 22
 23class PanoramaStitcher:
 24    """
 25    Complete pipeline for creating panoramas from multiple images.
 26
 27    Supports various feature detectors and matchers, with configurable
 28    homography estimation and blending strategies.
 29    """
 30
 31    def __init__(self,
 32                 detector: str = 'sift',
 33                 matcher: str = 'flann',
 34                 ratio_thresh: float = 0.75,
 35                 ransac_reproj_thresh: float = 5.0,
 36                 blend_mode: str = 'linear'):
 37        """
 38        Args:
 39            detector: Feature detector ('sift', 'orb')
 40            matcher: Feature matcher ('flann', 'bf')
 41            ratio_thresh: Ratio test threshold for Lowe's ratio test
 42            ransac_reproj_thresh: RANSAC reprojection error threshold
 43            blend_mode: Blending mode ('linear', 'multiband')
 44        """
 45        self.detector_type = detector.lower()
 46        self.matcher_type = matcher.lower()
 47        self.ratio_thresh = ratio_thresh
 48        self.ransac_reproj_thresh = ransac_reproj_thresh
 49        self.blend_mode = blend_mode
 50
 51        # Initialize detector
 52        self.detector = self._create_detector()
 53
 54        # Initialize matcher
 55        self.matcher = self._create_matcher()
 56
 57    def _create_detector(self):
 58        """Create feature detector."""
 59        if self.detector_type == 'sift':
 60            try:
 61                return cv2.SIFT_create()
 62            except AttributeError:
 63                print("SIFT not available, falling back to ORB")
 64                self.detector_type = 'orb'
 65                return cv2.ORB_create(nfeatures=2000)
 66        elif self.detector_type == 'orb':
 67            return cv2.ORB_create(nfeatures=2000)
 68        else:
 69            raise ValueError(f"Unknown detector: {self.detector_type}")
 70
 71    def _create_matcher(self):
 72        """Create feature matcher."""
 73        if self.matcher_type == 'flann':
 74            if self.detector_type == 'sift':
 75                # FLANN parameters for SIFT
 76                index_params = dict(algorithm=1, trees=5)  # FLANN_INDEX_KDTREE
 77                search_params = dict(checks=50)
 78            else:
 79                # FLANN parameters for ORB
 80                index_params = dict(
 81                    algorithm=6,  # FLANN_INDEX_LSH
 82                    table_number=6,
 83                    key_size=12,
 84                    multi_probe_level=1
 85                )
 86                search_params = dict(checks=50)
 87
 88            return cv2.FlannBasedMatcher(index_params, search_params)
 89        elif self.matcher_type == 'bf':
 90            # BFMatcher with appropriate norm
 91            norm_type = cv2.NORM_L2 if self.detector_type == 'sift' else cv2.NORM_HAMMING
 92            return cv2.BFMatcher(norm_type, crossCheck=False)
 93        else:
 94            raise ValueError(f"Unknown matcher: {self.matcher_type}")
 95
 96    def detect_and_describe(self, image: np.ndarray) -> Tuple[List, np.ndarray]:
 97        """
 98        Detect keypoints and compute descriptors.
 99
100        Args:
101            image: Input image (grayscale or color)
102
103        Returns:
104            Tuple of (keypoints, descriptors)
105        """
106        if len(image.shape) == 3:
107            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
108        else:
109            gray = image
110
111        keypoints, descriptors = self.detector.detectAndCompute(gray, None)
112        return keypoints, descriptors
113
114    def match_features(self,
115                      desc1: np.ndarray,
116                      desc2: np.ndarray) -> List[cv2.DMatch]:
117        """
118        Match features between two images using ratio test.
119
120        Args:
121            desc1, desc2: Feature descriptors
122
123        Returns:
124            List of good matches
125        """
126        # Ensure descriptors are in correct format
127        if self.detector_type == 'orb' and desc1.dtype != np.uint8:
128            desc1 = desc1.astype(np.uint8)
129            desc2 = desc2.astype(np.uint8)
130        elif self.detector_type == 'sift' and desc1.dtype != np.float32:
131            desc1 = desc1.astype(np.float32)
132            desc2 = desc2.astype(np.float32)
133
134        # Find k=2 best matches for ratio test
135        matches = self.matcher.knnMatch(desc1, desc2, k=2)
136
137        # Apply Lowe's ratio test
138        good_matches = []
139        for match_pair in matches:
140            if len(match_pair) == 2:
141                m, n = match_pair
142                if m.distance < self.ratio_thresh * n.distance:
143                    good_matches.append(m)
144
145        return good_matches
146
147    def estimate_homography(self,
148                          kp1: List,
149                          kp2: List,
150                          matches: List[cv2.DMatch]) -> Tuple[Optional[np.ndarray], np.ndarray]:
151        """
152        Estimate homography matrix using RANSAC.
153
154        Args:
155            kp1, kp2: Keypoints from both images
156            matches: Good matches
157
158        Returns:
159            Tuple of (homography matrix, mask of inliers)
160        """
161        if len(matches) < 4:
162            print(f"Insufficient matches: {len(matches)} (need at least 4)")
163            return None, None
164
165        # Extract matched keypoint coordinates
166        pts1 = np.float32([kp1[m.queryIdx].pt for m in matches])
167        pts2 = np.float32([kp2[m.trainIdx].pt for m in matches])
168
169        # Compute homography with RANSAC
170        H, mask = cv2.findHomography(
171            pts1, pts2,
172            cv2.RANSAC,
173            self.ransac_reproj_thresh
174        )
175
176        if H is None:
177            print("Homography estimation failed")
178            return None, None
179
180        # Check for degenerate homography
181        if not self._is_valid_homography(H):
182            print("Degenerate homography detected")
183            return None, None
184
185        return H, mask
186
187    def _is_valid_homography(self, H: np.ndarray) -> bool:
188        """Check if homography is valid (not degenerate)."""
189        # Check determinant
190        det = np.linalg.det(H)
191        if abs(det) < 1e-6:
192            return False
193
194        # Check if transformation is too extreme
195        # (prevents warping to infinity)
196        h_norm = H / H[2, 2]  # Normalize
197        if abs(h_norm[2, 0]) > 0.001 or abs(h_norm[2, 1]) > 0.001:
198            # Perspective component too large
199            return False
200
201        return True
202
203    def warp_images(self,
204                   img1: np.ndarray,
205                   img2: np.ndarray,
206                   H: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
207        """
208        Warp img1 to align with img2 using homography H.
209
210        Args:
211            img1: Source image to warp
212            img2: Target/reference image
213            H: Homography matrix
214
215        Returns:
216            Tuple of (warped panorama, offset of img2 in panorama)
217        """
218        h1, w1 = img1.shape[:2]
219        h2, w2 = img2.shape[:2]
220
221        # Get corners of img1
222        corners1 = np.float32([
223            [0, 0],
224            [w1, 0],
225            [w1, h1],
226            [0, h1]
227        ]).reshape(-1, 1, 2)
228
229        # Get corners of img2
230        corners2 = np.float32([
231            [0, 0],
232            [w2, 0],
233            [w2, h2],
234            [0, h2]
235        ]).reshape(-1, 1, 2)
236
237        # Transform corners of img1
238        corners1_transformed = cv2.perspectiveTransform(corners1, H)
239
240        # Combine all corners
241        all_corners = np.concatenate((corners2, corners1_transformed), axis=0)
242
243        # Find bounding box
244        [x_min, y_min] = np.int32(all_corners.min(axis=0).ravel() - 0.5)
245        [x_max, y_max] = np.int32(all_corners.max(axis=0).ravel() + 0.5)
246
247        # Translation to bring all points into positive coordinates
248        translation = np.array([
249            [1, 0, -x_min],
250            [0, 1, -y_min],
251            [0, 0, 1]
252        ])
253
254        # Warp img1
255        output_size = (x_max - x_min, y_max - y_min)
256        img1_warped = cv2.warpPerspective(
257            img1,
258            translation @ H,
259            output_size,
260            flags=cv2.INTER_LINEAR
261        )
262
263        # Position where img2 will be placed
264        img2_offset = (-x_min, -y_min)
265
266        return img1_warped, img2_offset
267
268    def blend_images(self,
269                    img1_warped: np.ndarray,
270                    img2: np.ndarray,
271                    offset: Tuple[int, int]) -> np.ndarray:
272        """
273        Blend two images using specified blending mode.
274
275        Args:
276            img1_warped: Warped first image
277            img2: Second image
278            offset: Position to place img2 in the panorama
279
280        Returns:
281            Blended panorama
282        """
283        panorama = img1_warped.copy()
284        x_offset, y_offset = offset
285        h2, w2 = img2.shape[:2]
286
287        if self.blend_mode == 'linear':
288            # Linear blending in overlap region
289            # Create mask for img2
290            mask2 = np.zeros(panorama.shape[:2], dtype=np.float32)
291            mask2[y_offset:y_offset+h2, x_offset:x_offset+w2] = 1.0
292
293            # Create mask for img1 (existing content)
294            mask1 = (img1_warped.sum(axis=2) > 0).astype(np.float32)
295
296            # Find overlap region
297            overlap = (mask1 * mask2) > 0
298
299            if overlap.any():
300                # Compute distance transforms for smooth blending
301                dist1 = cv2.distanceTransform((mask1 > 0).astype(np.uint8),
302                                             cv2.DIST_L2, 5)
303                dist2 = cv2.distanceTransform((mask2 > 0).astype(np.uint8),
304                                             cv2.DIST_L2, 5)
305
306                # Normalize distances in overlap region
307                total_dist = dist1 + dist2 + 1e-6
308                alpha1 = dist1 / total_dist
309                alpha2 = dist2 / total_dist
310
311                # Blend
312                for c in range(3):
313                    panorama[overlap, c] = (
314                        alpha1[overlap] * img1_warped[overlap, c] +
315                        alpha2[overlap] * img2[y_offset:y_offset+h2,
316                                               x_offset:x_offset+w2][overlap[y_offset:y_offset+h2,
317                                                                            x_offset:x_offset+w2], c]
318                    ).astype(np.uint8)
319
320            # Place img2 in non-overlap region
321            non_overlap = (mask2 > 0) & (mask1 == 0)
322            panorama[non_overlap] = img2[non_overlap[y_offset:y_offset+h2,
323                                                     x_offset:x_offset+w2]]
324
325        else:  # Simple alpha blending
326            panorama[y_offset:y_offset+h2, x_offset:x_offset+w2] = img2
327
328        return panorama
329
330    def stitch(self,
331              images: List[np.ndarray],
332              visualize: bool = False) -> Optional[np.ndarray]:
333        """
334        Stitch multiple images into a panorama.
335
336        Args:
337            images: List of images to stitch (left to right order)
338            visualize: Whether to show intermediate results
339
340        Returns:
341            Stitched panorama or None if failed
342        """
343        if len(images) < 2:
344            print("Need at least 2 images to stitch")
345            return None
346
347        # Start with the middle image as base
348        result = images[len(images) // 2].copy()
349
350        # Stitch images to the right
351        for i in range(len(images) // 2 + 1, len(images)):
352            print(f"\nStitching image {i+1}/{len(images)}...")
353            result = self._stitch_pair(result, images[i], visualize)
354            if result is None:
355                print(f"Failed to stitch image {i+1}")
356                return None
357
358        # Stitch images to the left (reverse order)
359        for i in range(len(images) // 2 - 1, -1, -1):
360            print(f"\nStitching image {i+1}/{len(images)}...")
361            result = self._stitch_pair(images[i], result, visualize)
362            if result is None:
363                print(f"Failed to stitch image {i+1}")
364                return None
365
366        return result
367
368    def _stitch_pair(self,
369                    img1: np.ndarray,
370                    img2: np.ndarray,
371                    visualize: bool = False) -> Optional[np.ndarray]:
372        """Stitch a pair of images."""
373        # Detect and describe
374        kp1, desc1 = self.detect_and_describe(img1)
375        kp2, desc2 = self.detect_and_describe(img2)
376
377        print(f"  Features: {len(kp1)} (left), {len(kp2)} (right)")
378
379        if desc1 is None or desc2 is None:
380            print("  Feature detection failed")
381            return None
382
383        # Match features
384        matches = self.match_features(desc1, desc2)
385        print(f"  Good matches: {len(matches)}")
386
387        if visualize:
388            matches_img = cv2.drawMatches(
389                img1, kp1, img2, kp2, matches[:50], None,
390                flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS
391            )
392            cv2.imshow('Feature Matches', matches_img)
393            cv2.waitKey(500)
394
395        # Estimate homography
396        H, mask = self.estimate_homography(kp1, kp2, matches)
397
398        if H is None:
399            return None
400
401        inliers = np.sum(mask) if mask is not None else 0
402        print(f"  Inliers: {inliers}/{len(matches)}")
403
404        # Warp and blend
405        img1_warped, offset = self.warp_images(img1, img2, H)
406
407        if visualize:
408            cv2.imshow('Warped Image', img1_warped)
409            cv2.waitKey(500)
410
411        result = self.blend_images(img1_warped, img2, offset)
412
413        return result
414
415
def generate_synthetic_images(num_images: int = 3,
                             base_size: int = 400,
                             overlap: float = 0.3) -> List[np.ndarray]:
    """
    Generate synthetic images for testing panorama stitching.

    Creates images with a checkerboard pattern that are shifted/rotated
    versions of each other.

    Args:
        num_images: Number of images to generate
        base_size: Base image size (each image is base_size x base_size)
        overlap: Overlap ratio between consecutive images (0..1)

    Returns:
        List of generated BGR images
    """
    images = []

    # BUG FIX: derive the canvas width from the *rounded* step instead of
    # rounding step and full_width independently. With independent
    # rounding, the last crop (x_start + base_size) could overrun the
    # canvas and silently yield an image narrower than base_size.
    step = int(base_size * (1 - overlap))
    full_width = step * (num_images - 1) + base_size
    canvas = np.zeros((base_size, full_width, 3), dtype=np.uint8)

    # Draw checkerboard background
    square_size = 40
    for i in range(0, base_size, square_size):
        for j in range(0, full_width, square_size):
            if ((i // square_size) + (j // square_size)) % 2 == 0:
                canvas[i:i+square_size, j:j+square_size] = [100, 100, 100]
            else:
                canvas[i:i+square_size, j:j+square_size] = [200, 200, 200]

    # Scatter random shapes so the feature detector finds distinctive
    # keypoints (a bare checkerboard is highly self-similar).
    for _ in range(num_images * 3):
        x = np.random.randint(50, full_width - 50)
        y = np.random.randint(50, base_size - 50)
        # randint's upper bound is exclusive: 256 covers the full 0-255
        # color range (the previous bound of 255 never produced 255).
        color = tuple(np.random.randint(0, 256, 3).tolist())

        if np.random.random() > 0.5:
            cv2.circle(canvas, (x, y), 20, color, -1)
        else:
            cv2.rectangle(canvas, (x-15, y-15), (x+15, y+15), color, -1)

    # Label each region so stitched results are easy to eyeball
    for i in range(num_images):
        x = int(i * full_width / num_images) + 50
        cv2.putText(canvas, f'Region {i+1}', (x, base_size // 2),
                   cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

    # Extract overlapping crops, perturbing all but the first
    for i in range(num_images):
        x_start = i * step
        img = canvas[:, x_start:x_start + base_size].copy()

        # Slight random rotation so the homography is non-trivial
        if i > 0:
            angle = np.random.uniform(-2, 2)
            M = cv2.getRotationMatrix2D((base_size // 2, base_size // 2), angle, 1.0)
            img = cv2.warpAffine(img, M, (base_size, base_size))

        # Add Gaussian noise for realism (int16 avoids uint8 wraparound)
        noise = np.random.normal(0, 5, img.shape).astype(np.int16)
        img = np.clip(img.astype(np.int16) + noise, 0, 255).astype(np.uint8)

        images.append(img)

    return images
485
486
def demo_basic_stitching():
    """Demonstrate basic panorama stitching."""
    print("Demo: Basic Panorama Stitching")
    print("-" * 50)

    # Build a synthetic test set with generous overlap
    print("Generating synthetic images...")
    imgs = generate_synthetic_images(num_images=3, overlap=0.4)

    # Show all inputs side by side before stitching
    cv2.imshow('Input Images', np.hstack(imgs))
    cv2.waitKey(1000)

    # SIFT + FLANN with distance-weighted linear blending
    stitcher = PanoramaStitcher(detector='sift',
                                matcher='flann',
                                blend_mode='linear')

    print("\nStitching panorama...")
    pano = stitcher.stitch(imgs, visualize=True)

    if pano is None:
        print("\nFailed to create panorama")
    else:
        print("\nPanorama created successfully!")
        cv2.imshow('Panorama', pano)
        cv2.waitKey(0)

    cv2.destroyAllWindows()
520
521
def demo_comparison():
    """Compare different detector/matcher combinations."""
    print("\nDemo: Detector/Matcher Comparison")
    print("-" * 50)

    imgs = generate_synthetic_images(num_images=2, overlap=0.5)

    # (detector, matcher) pairs to pit against each other
    for detector, matcher in (('sift', 'flann'), ('sift', 'bf'), ('orb', 'bf')):
        print(f"\n{detector.upper()} + {matcher.upper()}")
        print("-" * 30)

        try:
            pano = PanoramaStitcher(detector=detector,
                                    matcher=matcher).stitch(imgs.copy())
            if pano is not None:
                cv2.imshow(f'{detector.upper()}-{matcher.upper()}', pano)
                cv2.waitKey(1500)
        except Exception as e:
            # A failing configuration shouldn't abort the comparison
            print(f"Error: {e}")

    cv2.waitKey(0)
    cv2.destroyAllWindows()
552
553
if __name__ == "__main__":
    print("Panorama Stitching Demonstrations")
    print("=" * 50)

    try:
        # Run each demo in order
        for demo in (demo_basic_stitching, demo_comparison):
            demo()
    except KeyboardInterrupt:
        print("\n\nDemo interrupted by user")
    except Exception as e:
        # Top-level boundary: report and dump the traceback, don't crash
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()

    cv2.destroyAllWindows()
    print("\nAll demos completed!")