Segment Anything Model (SAM)

Learning Objectives

  • Understand SAM's "promptable segmentation" paradigm
  • Grasp the Image Encoder, Prompt Encoder, and Mask Decoder architecture
  • Understand SAM's training data and methodology
  • Learn how to use SAM in practice

1. SAM Overview

1.1 A Foundation Model for Segmentation

SAM (Segment Anything Model) is a vision foundation model released by Meta AI in 2023 that can segment any object in any image.

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    What Makes SAM Different                     β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚  Conventional segmentation:                                     β”‚
β”‚  β€’ Fixed set of classes only (person, car, ...)                 β”‚
β”‚  β€’ Only objects present in the training data                    β”‚
β”‚  β€’ One model per class set, or a fixed number of classes        β”‚
β”‚                                                                 β”‚
β”‚  SAM:                                                           β”‚
β”‚  β€’ Segments any object                                          β”‚
β”‚  β€’ A prompt specifies which object you want                     β”‚
β”‚  β€’ Zero-shot: handles unseen objects out of the box             β”‚
β”‚                                                                 β”‚
β”‚  Prompt types:                                                  β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”         β”‚
β”‚  β”‚ Point   β”‚ Click location (foreground/background)  β”‚         β”‚
β”‚  β”‚ Box     β”‚ Bounding box                            β”‚         β”‚
β”‚  β”‚ Mask    β”‚ Rough mask (for refinement)             β”‚         β”‚
β”‚  β”‚ Text    β”‚ Text description (via Grounding SAM)    β”‚         β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜         β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

1.2 The SA-1B Dataset

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    SA-1B Dataset                                β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚  Scale:                                                         β”‚
β”‚  β€’ 11M images                                                   β”‚
β”‚  β€’ 1.1B (1.1 billion) masks                                     β”‚
β”‚  β€’ ~100 masks per image on average                              β”‚
β”‚                                                                 β”‚
β”‚  Collection process (Data Engine):                              β”‚
β”‚                                                                 β”‚
β”‚  Phase 1: Assisted Manual (4.3M masks)                          β”‚
β”‚  ───────────────────────────────────                            β”‚
β”‚  β€’ Professional annotators label with SAM assisting             β”‚
β”‚  β€’ SAM proposes β†’ a human corrects                              β”‚
β”‚                                                                 β”‚
β”‚  Phase 2: Semi-Automatic (5.9M masks)                           β”‚
β”‚  ───────────────────────────────────                            β”‚
β”‚  β€’ SAM auto-generates the confident masks                       β”‚
β”‚  β€’ Humans label only what remains                               β”‚
β”‚                                                                 β”‚
β”‚  Phase 3: Fully Automatic (1.1B masks)                          β”‚
β”‚  ───────────────────────────────────                            β”‚
β”‚  β€’ Masks generated from a 32Γ—32 grid of points                  β”‚
β”‚  β€’ Filtering selects the final masks                            β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
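Phase 3's grid prompting can be sketched as a small helper that lays points at cell centers in normalized coordinates (a minimal version; the function name is mine, and the real generator additionally works over image crops):

```python
import numpy as np

def build_point_grid(points_per_side: int) -> np.ndarray:
    """Evenly spaced point grid in normalized [0, 1] coordinates,
    one point per cell center, used to prompt SAM automatically."""
    offset = 1.0 / (2 * points_per_side)
    coords_1d = np.linspace(offset, 1.0 - offset, points_per_side)
    xs, ys = np.meshgrid(coords_1d, coords_1d)
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)  # (N, 2)

grid = build_point_grid(32)
print(grid.shape)  # (1024, 2)
```

Each of the 1024 points is then fed to SAM as a foreground point prompt, and the resulting masks are filtered and deduplicated.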

2. SAM Architecture

2.1 Overall Structure

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    SAM Architecture                              β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚                         Input Image                             β”‚
β”‚                              β”‚                                  β”‚
β”‚                              β–Ό                                  β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”    β”‚
β”‚  β”‚                   Image Encoder                          β”‚    β”‚
β”‚  β”‚           (MAE pre-trained ViT-H/16)                    β”‚    β”‚
β”‚  β”‚                                                          β”‚    β”‚
β”‚  β”‚  β€’ 1024Γ—1024 μž…λ ₯ β†’ 64Γ—64 feature map                   β”‚    β”‚
β”‚  β”‚  β€’ 632M parameters                                       β”‚    β”‚
β”‚  β”‚  β€’ Runs once per image (expensive)                     β”‚    β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜    β”‚
β”‚                              β”‚                                  β”‚
β”‚                              β–Ό                                  β”‚
β”‚                     Image Embeddings                            β”‚
β”‚                       (64Γ—64Γ—256)                               β”‚
β”‚                              β”‚                                  β”‚
β”‚              β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                  β”‚
β”‚              β”‚                               β”‚                  β”‚
β”‚              β–Ό                               β–Ό                  β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”           β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”         β”‚
β”‚  β”‚  Prompt Encoder   β”‚           β”‚  Prompt Encoder   β”‚         β”‚
β”‚  β”‚  (Points/Boxes)   β”‚           β”‚  (Dense: Mask)    β”‚         β”‚
β”‚  β”‚                   β”‚           β”‚                   β”‚         β”‚
β”‚  β”‚  Sparse Embed     β”‚           β”‚  Conv downscale   β”‚         β”‚
β”‚  β”‚  (NΓ—256)          β”‚           β”‚  (256Γ—64Γ—64)      β”‚         β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜           β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜         β”‚
β”‚            β”‚                               β”‚                    β”‚
β”‚            β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                    β”‚
β”‚                            β–Ό                                    β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”    β”‚
β”‚  β”‚                   Mask Decoder                           β”‚    β”‚
β”‚  β”‚           (Lightweight Transformer)                      β”‚    β”‚
β”‚  β”‚                                                          β”‚    β”‚
β”‚  β”‚  β€’ 2-layer Transformer decoder                          β”‚    β”‚
β”‚  β”‚  β€’ Cross-attention: prompt ↔ image                      β”‚    β”‚
β”‚  β”‚  β€’ Self-attention: prompt tokens                        β”‚    β”‚
β”‚  β”‚  β€’ 4M parameters (very lightweight)                   β”‚    β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜    β”‚
β”‚                            β”‚                                    β”‚
β”‚              β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                      β”‚
β”‚              β–Ό                           β–Ό                      β”‚
β”‚         3 Mask Outputs             IoU Scores                   β”‚
β”‚     (256Γ—256, upscaled)          (confidence)                   β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
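The key design point in the diagram is the cost asymmetry: the heavy encoder runs once per image, and its embeddings are reused across many cheap prompt decodings. A purely illustrative sketch with stand-in modules:

```python
class CountingEncoder:
    """Stand-in for the heavy image encoder: run once per image."""
    def __init__(self):
        self.calls = 0

    def __call__(self, image):
        self.calls += 1
        return f"embeddings({image})"

class CountingDecoder:
    """Stand-in for the light mask decoder: run once per prompt."""
    def __init__(self):
        self.calls = 0

    def __call__(self, embeddings, prompt):
        self.calls += 1
        return f"mask({prompt})"

encoder, decoder = CountingEncoder(), CountingDecoder()
embeddings = encoder("img.jpg")  # heavy pass: exactly once
masks = [decoder(embeddings, p) for p in ["click1", "click2", "box"]]
print(encoder.calls, decoder.calls)  # 1 3
```

This is why interactive use feels instant: after `set_image`, each new click only pays the 4M-parameter decoder.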

2.2 Image Encoder

"""
SAM Image Encoder: MAE pre-trained ViT-H

νŠΉμ§•:
- ViT-H/16: 632M parameters
- μž…λ ₯: 1024Γ—1024 (고해상도)
- 좜λ ₯: 64Γ—64Γ—256 feature map
- Positional Embedding: Windowed + Global attention

μ™œ MAE pre-training?
- λ§ˆμŠ€ν‚Ή 기반 ν•™μŠ΅μœΌλ‘œ dense prediction에 적합
- 자기 지도 ν•™μŠ΅μœΌλ‘œ λŒ€κ·œλͺ¨ 데이터 ν™œμš©
- Patch-level ν‘œν˜„ ν•™μŠ΅μ— 효과적
"""

import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Standard ViT block (a stand-in for SAM's windowed-attention blocks)."""
    def __init__(self, dim, num_heads):
        super().__init__()
        self.block = nn.TransformerEncoderLayer(
            dim, num_heads, dim_feedforward=dim * 4, batch_first=True
        )

    def forward(self, x):
        return self.block(x)


class LayerNorm2d(nn.Module):
    """Channel-wise LayerNorm for (B, C, H, W) tensors, as in the SAM repo.
    Plain nn.LayerNorm normalizes the last dimension and would not work here."""
    def __init__(self, num_channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]


class SAMImageEncoder(nn.Module):
    """
    SAM's Image Encoder (simplified version)

    The real model is a ViT-H with windowed attention;
    this is a simplification to illustrate the structure.
    """
    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        embed_dim: int = 1280,  # ViT-H
        depth: int = 32,
        num_heads: int = 16,
        out_chans: int = 256,
    ):
        super().__init__()

        self.patch_embed = nn.Conv2d(3, embed_dim, patch_size, patch_size)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, (img_size // patch_size) ** 2, embed_dim)
        )

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])

        # Neck: project 1280-d tokens down to the 256-d embedding SAM uses
        self.neck = nn.Sequential(
            nn.Conv2d(embed_dim, out_chans, kernel_size=1),
            LayerNorm2d(out_chans),
            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1),
            LayerNorm2d(out_chans),
        )

    def forward(self, x):
        # x: (B, 3, 1024, 1024)
        x = self.patch_embed(x)  # (B, embed_dim, 64, 64)
        x = x.flatten(2).transpose(1, 2)  # (B, 4096, embed_dim)
        x = x + self.pos_embed

        for block in self.blocks:
            x = block(x)

        # Reshape the token sequence back to a 2D feature map
        B, N, C = x.shape
        H = W = int(N ** 0.5)
        x = x.transpose(1, 2).reshape(B, C, H, W)

        x = self.neck(x)  # (B, 256, 64, 64)
        return x
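The shape bookkeeping above (1024Γ—1024 input β†’ 64Γ—64 = 4096 tokens) can be checked by running the patchify step alone, a 16Γ—16 stride-16 convolution:

```python
import torch
import torch.nn as nn

# The patchify step: each 16Γ—16 pixel patch becomes one 1280-d token (ViT-H)
patch_embed = nn.Conv2d(3, 1280, kernel_size=16, stride=16)

x = torch.randn(1, 3, 1024, 1024)
tokens = patch_embed(x).flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 4096, 1280])
```

4096 tokens at 1280 dimensions through 32 transformer blocks is what makes the encoder the expensive part of SAM.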

2.3 Prompt Encoder

class SAMPromptEncoder(nn.Module):
    """
    SAM Prompt Encoder (simplified)

    Prompt types:
    1. Points: (x, y) + label (foreground/background)
    2. Boxes: (x1, y1, x2, y2)
    3. Masks: a previous mask (for refinement)

    Assumes two helpers: `PositionalEncoding`, mapping (x, y) coordinates to
    embed_dim-dimensional features (SAM uses random Fourier features), and
    `LayerNorm2d`, a channel-wise LayerNorm over (B, C, H, W) tensors.
    """
    def __init__(self, embed_dim: int = 256, image_size: int = 1024):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size

        # Point type embeddings
        self.point_embeddings = nn.ModuleList([
            nn.Embedding(1, embed_dim),  # foreground
            nn.Embedding(1, embed_dim),  # background
        ])

        # Positional encoding for coordinates
        self.pe_layer = PositionalEncoding(embed_dim, image_size)

        # Box corner embeddings (top-left, bottom-right)
        self.box_embeddings = nn.Embedding(2, embed_dim)

        # Mask encoder (for dense prompts)
        self.mask_downscaler = nn.Sequential(
            nn.Conv2d(1, embed_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(embed_dim // 4),
            nn.GELU(),
            nn.Conv2d(embed_dim // 4, embed_dim, kernel_size=2, stride=2),
            LayerNorm2d(embed_dim),
            nn.GELU(),
            nn.Conv2d(embed_dim, embed_dim, kernel_size=1),
        )

        # Learnable embedding used when no mask prompt is given
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def forward(self, points=None, boxes=None, masks=None):
        """
        Args:
            points: tuple of (B, N, 2) coordinates and (B, N) labels
            boxes: (B, 4) bounding boxes
            masks: (B, 1, H, W) previous masks

        Returns:
            sparse_embeddings: (B, N_prompts, embed_dim)
            dense_embeddings: (B, embed_dim, H, W)
        """
        sparse_embeddings = []

        # Point prompts: positional encoding + a per-type embedding
        if points is not None:
            coords, labels = points
            point_embed = self.pe_layer(coords)  # (B, N, embed_dim)

            fg = self.point_embeddings[0].weight  # (1, embed_dim)
            bg = self.point_embeddings[1].weight
            for i in range(coords.shape[1]):
                label = labels[:, i]  # (B,)
                # Select the type embedding per sample (1 = foreground)
                type_embed = torch.where(label.unsqueeze(-1) == 1, fg, bg)
                sparse_embeddings.append(point_embed[:, i] + type_embed)

        # Box prompts: a box is encoded as its two corner points
        if boxes is not None:
            corners = boxes.reshape(-1, 2, 2)  # (B, 2, 2)
            corner_embed = self.pe_layer(corners)
            corner_embed = corner_embed + self.box_embeddings.weight
            sparse_embeddings.extend([corner_embed[:, 0], corner_embed[:, 1]])

        sparse_embeddings = (
            torch.stack(sparse_embeddings, dim=1) if sparse_embeddings else None
        )

        # Dense prompt (mask)
        if masks is not None:
            dense_embeddings = self.mask_downscaler(masks)
        else:
            # No mask given: broadcast the learnable "no mask" embedding
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1)
            dense_embeddings = dense_embeddings.expand(-1, -1, 64, 64)

        return sparse_embeddings, dense_embeddings
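The `PositionalEncoding` module referenced above is left abstract. SAM encodes coordinates with random Fourier features; the idea can be sketched standalone (function name and defaults are mine):

```python
import torch

def fourier_point_encoding(coords, image_size=1024, num_feats=128, seed=0):
    """Random Fourier features for 2D points: normalize pixel coordinates
    to [0, 1], project through a fixed Gaussian matrix, take sin/cos."""
    g = torch.Generator().manual_seed(seed)
    gauss = torch.randn(2, num_feats, generator=g)
    normed = coords / image_size
    proj = 2 * torch.pi * normed @ gauss
    return torch.cat([proj.sin(), proj.cos()], dim=-1)  # (..., 2*num_feats)

emb = fourier_point_encoding(torch.tensor([[500.0, 375.0]]))
print(emb.shape)  # torch.Size([1, 256])
```

Nearby points get similar embeddings and distant points decorrelate, which is exactly what the decoder's attention needs to localize a click.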

2.4 Mask Decoder

class SAMMaskDecoder(nn.Module):
    """
    SAM Mask Decoder (simplified)

    Structure:
    - 2-layer two-way Transformer decoder
    - Cross-attention: tokens ↔ image
    - Self-attention: tokens
    - 3 output masks (to resolve prompt ambiguity)
    - IoU prediction head (mask quality estimate)
    """
    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 8,
        num_mask_tokens: int = 4,  # 1 IoU token + 3 mask tokens
    ):
        super().__init__()

        # Learnable output tokens
        self.mask_tokens = nn.Embedding(num_mask_tokens, embed_dim)

        # Two-way Transformer
        self.transformer = TwoWayTransformer(
            depth=2,
            embed_dim=embed_dim,
            num_heads=num_heads,
        )

        # IoU (mask quality) head: one score per output mask
        self.iou_prediction_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, num_mask_tokens - 1),  # 3 IoU scores
        )

        # Mask head: upscale 64Γ—64 features to 256Γ—256 with 3 mask channels.
        # (The real SAM instead takes a dot product between each mask token
        # and the upscaled features; this direct head is a simplification.)
        self.mask_prediction_head = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim // 4, kernel_size=2, stride=2),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dim // 4, embed_dim // 8, kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(embed_dim // 8, num_mask_tokens - 1, kernel_size=1),
        )

    def forward(self, image_embeddings, sparse_embeddings, dense_embeddings):
        """
        Args:
            image_embeddings: (B, 256, 64, 64)
            sparse_embeddings: (B, N_prompts, 256)
            dense_embeddings: (B, 256, 64, 64)

        Returns:
            masks: (B, 3, 256, 256)
            iou_predictions: (B, 3)
        """
        B, C, H, W = image_embeddings.shape

        # Prepend the learnable output tokens to the prompt tokens
        mask_tokens = self.mask_tokens.weight.unsqueeze(0).expand(B, -1, -1)
        tokens = torch.cat([mask_tokens, sparse_embeddings], dim=1)

        # Add the dense prompt to the image embedding, then flatten to tokens
        src = (image_embeddings + dense_embeddings).flatten(2).transpose(1, 2)
        image_pe = dense_embeddings.flatten(2).transpose(1, 2)

        # Two-way Transformer: cross-attention between tokens and image
        tokens, src = self.transformer(tokens, src, image_pe)

        # Token 0 predicts the quality (IoU) of the 3 output masks
        iou_predictions = self.iou_prediction_head(tokens[:, 0])

        # Upscale the updated image features and predict the masks
        src = src.transpose(1, 2).reshape(B, C, H, W)
        masks = self.mask_prediction_head(src)  # (B, 3, 256, 256)

        return masks, iou_predictions


class TwoWayAttentionBlock(nn.Module):
    """One two-way attention layer: token self-attention, token→image
    cross-attention, token MLP, then image→token cross-attention."""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.cross_t2i = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.cross_i2t = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )
        self.norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(4)])

    def forward(self, tokens, image, image_pe):
        # tokens: (B, N_t, C), image / image_pe: (B, N_i, C)
        out, _ = self.self_attn(tokens, tokens, tokens)
        tokens = self.norms[0](tokens + out)
        out, _ = self.cross_t2i(tokens, image + image_pe, image)
        tokens = self.norms[1](tokens + out)
        tokens = self.norms[2](tokens + self.mlp(tokens))
        out, _ = self.cross_i2t(image + image_pe, tokens, tokens)
        image = self.norms[3](image + out)
        return tokens, image


class TwoWayTransformer(nn.Module):
    """
    Two-way Transformer for SAM

    Attention flows in both directions:
    - Token β†’ Image cross-attention
    - Image β†’ Token cross-attention
    - Token self-attention
    """
    def __init__(self, depth, embed_dim, num_heads):
        super().__init__()
        self.layers = nn.ModuleList([
            TwoWayAttentionBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])

    def forward(self, tokens, image, image_pe):
        for layer in self.layers:
            tokens, image = layer(tokens, image, image_pe)
        return tokens, image

3. Using SAM

3.1 Basic Usage

from segment_anything import SamPredictor, sam_model_registry
import cv2
import numpy as np

# λͺ¨λΈ λ‘œλ“œ
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")
predictor = SamPredictor(sam)

# 이미지 μ„€μ •
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# Point prompt둜 μ„Έκ·Έλ©˜ν…Œμ΄μ…˜
input_point = np.array([[500, 375]])  # 클릭 μœ„μΉ˜
input_label = np.array([1])  # 1: foreground, 0: background

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,  # 3개 마슀크 좜λ ₯
)

# κ°€μž₯ 높은 score의 마슀크 선택
best_mask = masks[np.argmax(scores)]

3.2 Prompt Variations

# 1. Multiple points
input_points = np.array([[500, 375], [600, 400], [450, 350]])
input_labels = np.array([1, 1, 0])  # 2 foreground, 1 background

masks, scores, _ = predictor.predict(
    point_coords=input_points,
    point_labels=input_labels,
    multimask_output=False,  # single mask
)

# 2. Box prompt
input_box = np.array([100, 100, 500, 400])  # x1, y1, x2, y2

masks, scores, _ = predictor.predict(
    box=input_box,
    multimask_output=False,
)

# 3. Point + box combined
masks, scores, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)

# 4. Iterative refinement (feed the previous mask back in)
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=logits[np.argmax(scores)][None, :, :],  # previous logits
    multimask_output=False,
)

3.3 Automatic Mask Generation

from segment_anything import SamAutomaticMaskGenerator

# μžλ™ 마슀크 생성기
mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,           # 32Γ—32 grid
    pred_iou_thresh=0.88,         # IoU μž„κ³„κ°’
    stability_score_thresh=0.95,  # μ•ˆμ •μ„± μž„κ³„κ°’
    min_mask_region_area=100,     # μ΅œμ†Œ 마슀크 크기
)

# μ΄λ―Έμ§€μ˜ λͺ¨λ“  마슀크 생성
masks = mask_generator.generate(image)

# κ²°κ³Ό: list of dicts
# {
#     'segmentation': binary mask,
#     'area': mask area,
#     'bbox': bounding box,
#     'predicted_iou': IoU score,
#     'stability_score': stability score,
#     'crop_box': crop used for generation,
# }
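The `stability_score_thresh` filter measures how little a mask changes when the logit threshold is shifted up and down. A minimal version of the idea, assuming masks are cut at logit 0 (numpy sketch, names are mine):

```python
import numpy as np

def stability_score(logits: np.ndarray, offset: float = 1.0) -> float:
    """IoU between the mask thresholded at +offset and at -offset.
    A mask whose boundary barely moves under threshold shifts is 'stable'."""
    hi = (logits > offset).sum()   # stricter threshold
    lo = (logits > -offset).sum()  # looser threshold (superset of hi)
    return float(hi) / float(lo)

logits = np.full((4, 4), -5.0)
logits[1:3, 1:3] = 5.0           # a confidently predicted 2x2 object
print(stability_score(logits))   # 1.0
```

Masks with fuzzy, low-confidence boundaries score well below 1 and get filtered out.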

print(f"Found {len(masks)} masks")
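Downstream you usually filter and rank the returned records before drawing them. A sketch on dummy dicts (keys follow the result format above; the threshold values are illustrative):

```python
# Dummy records shaped like SamAutomaticMaskGenerator output entries
dummy_masks = [
    {"area": 5000, "predicted_iou": 0.95, "stability_score": 0.97},
    {"area": 120,  "predicted_iou": 0.70, "stability_score": 0.90},
    {"area": 9000, "predicted_iou": 0.91, "stability_score": 0.96},
]

# Keep confident masks, then draw large masks first so small ones stay visible
kept = [m for m in dummy_masks if m["predicted_iou"] >= 0.88]
kept.sort(key=lambda m: m["area"], reverse=True)
print([m["area"] for m in kept])  # [9000, 5000]
```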

# μ‹œκ°ν™”
import matplotlib.pyplot as plt

def show_masks(image, masks):
    plt.figure(figsize=(15, 10))
    plt.imshow(image)
    for mask in masks:
        m = mask['segmentation']
        color = np.random.random(3)
        colored_mask = np.zeros((*m.shape, 4))
        colored_mask[m] = [*color, 0.5]
        plt.imshow(colored_mask)
    plt.axis('off')
    plt.show()

show_masks(image, masks)

3.4 Using HuggingFace Transformers

from transformers import SamModel, SamProcessor
import torch
from PIL import Image

# λͺ¨λΈ λ‘œλ“œ
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# 이미지 λ‘œλ“œ
image = Image.open("image.jpg")

# Point prompt
input_points = [[[500, 375]]]  # batch of points

inputs = processor(image, input_points=input_points, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Post-process
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)

scores = outputs.iou_scores

4. SAM 2 (2024)

4.1 What SAM 2 Adds

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    SAM vs SAM 2                                 β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚  SAM (2023):                                                    β”‚
β”‚  β€’ Images only                                                  β”‚
β”‚  β€’ Frames processed independently                               β”‚
β”‚  β€’ Video: a prompt is needed on every frame                     β”‚
β”‚                                                                 β”‚
β”‚  SAM 2 (2024):                                                  β”‚
β”‚  β€’ Unified model for images and video                           β”‚
β”‚  β€’ Memory attention for temporal consistency                    β”‚
β”‚  β€’ Prompt once β†’ track through the whole video                  β”‚
β”‚                                                                 β”‚
β”‚  New components:                                                β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”         β”‚
β”‚  β”‚ Memory Encoder   β”‚ Encodes past-frame information  β”‚         β”‚
β”‚  β”‚ Memory Bank      β”‚ Stores past masks and features  β”‚         β”‚
β”‚  β”‚ Memory Attention β”‚ Current frame ↔ stored memory   β”‚         β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜         β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
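The Memory Bank can be thought of as a bounded FIFO buffer of recent frame results that memory attention conditions on. A purely illustrative sketch (capacity and entry format are mine, not SAM 2's actual internals):

```python
from collections import deque

class MemoryBank:
    """FIFO memory of recent frame features and masks."""
    def __init__(self, capacity: int = 6):
        self.entries = deque(maxlen=capacity)  # oldest entries drop off

    def add(self, frame_idx: int, features, mask):
        self.entries.append((frame_idx, features, mask))

    def context(self):
        # What memory attention would attend over for the current frame
        return list(self.entries)

bank = MemoryBank(capacity=3)
for t in range(5):
    bank.add(t, features=f"feat_{t}", mask=f"mask_{t}")
print([e[0] for e in bank.context()])  # [2, 3, 4]
```

Bounding the memory keeps per-frame cost constant no matter how long the video is.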

4.2 Using SAM 2 on Video

import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor

# Both a model config and a checkpoint are required
# (exact file names depend on the release you installed)
predictor = build_sam2_video_predictor(
    "configs/sam2/sam2_hiera_l.yaml",
    "sam2_hiera_large.pt",
    device="cuda",
)

# Initialize tracking state from the video
# (a directory of JPEG frames; recent releases also accept .mp4)
state = predictor.init_state(video_path="video.mp4")

with torch.inference_mode():
    # Prompt on the first frame
    _, object_ids, mask_logits = predictor.add_new_points_or_box(
        state,
        frame_idx=0,
        obj_id=1,
        points=np.array([[500, 375]], dtype=np.float32),
        labels=np.array([1], dtype=np.int32),
    )

    # Propagate through the remaining frames automatically
    for frame_idx, object_ids, mask_logits in predictor.propagate_in_video(state):
        # mask_logits: segmentation for each tracked object in this frame
        print(f"Frame {frame_idx}: {len(object_ids)} objects")

5. SAM Applications

5.1 Grounding SAM (Text β†’ Segment)

"""
Grounding SAM = Grounding DINO + SAM

1. Grounding DINO: ν…μŠ€νŠΈ β†’ λ°”μš΄λ”© λ°•μŠ€
2. SAM: λ°”μš΄λ”© λ°•μŠ€ β†’ μ„Έκ·Έλ©˜ν…Œμ΄μ…˜

κ²°κ³Ό: ν…μŠ€νŠΈ ν”„λ‘¬ν”„νŠΈλ‘œ μ„Έκ·Έλ©˜ν…Œμ΄μ…˜
"""

import torch
from torchvision.ops import box_convert
from groundingdino.util.inference import load_model, load_image, predict
from segment_anything import SamPredictor, sam_model_registry

# Detect boxes with Grounding DINO
# (load_model needs both a config and a checkpoint; paths depend on your install)
grounding_dino = load_model("GroundingDINO_SwinB_cfg.py", "groundingdino_swinb.pth")
image_source, image_tensor = load_image("image.jpg")
boxes, logits, phrases = predict(
    model=grounding_dino,
    image=image_tensor,
    caption="a cat",
    box_threshold=0.3,
    text_threshold=0.25,
)

# Grounding DINO returns normalized (cx, cy, w, h) boxes;
# SAM's box prompt expects pixel-space (x1, y1, x2, y2)
h, w, _ = image_source.shape
boxes_xyxy = box_convert(
    boxes * torch.tensor([w, h, w, h]), in_fmt="cxcywh", out_fmt="xyxy"
).numpy()

# Segment each detected box with SAM
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")
predictor = SamPredictor(sam)
predictor.set_image(image_source)

masks = []
for box in boxes_xyxy:
    mask, _, _ = predictor.predict(box=box, multimask_output=False)
    masks.append(mask)
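One practical detail worth spelling out: Grounding DINO emits boxes as normalized (cx, cy, w, h), while SAM's box prompt expects pixel (x1, y1, x2, y2). The conversion arithmetic, standalone (function name is mine):

```python
import numpy as np

def cxcywh_norm_to_xyxy_pixels(boxes: np.ndarray, w: int, h: int) -> np.ndarray:
    """Convert normalized (cx, cy, w, h) boxes to pixel (x1, y1, x2, y2)."""
    cx, cy, bw, bh = boxes.T
    x1, y1 = (cx - bw / 2) * w, (cy - bh / 2) * h
    x2, y2 = (cx + bw / 2) * w, (cy + bh / 2) * h
    return np.stack([x1, y1, x2, y2], axis=-1)

b = np.array([[0.5, 0.5, 0.25, 0.5]])          # box centered in the image
print(cxcywh_norm_to_xyxy_pixels(b, 800, 600))  # [[300. 150. 500. 450.]]
```

Forgetting this conversion is the most common reason Grounding SAM pipelines produce empty or misplaced masks.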

5.2 Interactive Annotation Tool

"""
SAM 기반 μΈν„°λž™ν‹°λΈŒ λ ˆμ΄λΈ”λ§ 도ꡬ

1. 이미지 λ‘œλ“œ
2. μ‚¬μš©μžκ°€ 포인트/λ°•μŠ€ 클릭
3. SAM이 μ‹€μ‹œκ°„ 마슀크 생성
4. μ‚¬μš©μžκ°€ μˆ˜μ • (positive/negative points)
5. μ΅œμ’… 마슀크 μ €μž₯
"""

import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry

class SAMAnnotator:
    def __init__(self, sam_checkpoint):
        self.sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
        self.predictor = SamPredictor(self.sam)
        self.points = []
        self.labels = []

    def set_image(self, image):
        self.image = image.copy()
        self.predictor.set_image(image)
        self.points = []
        self.labels = []

    def add_point(self, x, y, is_foreground=True):
        self.points.append([x, y])
        self.labels.append(1 if is_foreground else 0)
        return self.predict()

    def predict(self):
        if not self.points:
            return None

        masks, scores, _ = self.predictor.predict(
            point_coords=np.array(self.points),
            point_labels=np.array(self.labels),
            multimask_output=False,
        )
        return masks[0]

    def reset(self):
        self.points = []
        self.labels = []

# Example usage (together with an OpenCV mouse callback)
# annotator = SAMAnnotator("sam_vit_h.pth")
# annotator.set_image(image)
# mask = annotator.add_point(500, 375, is_foreground=True)

5.3 Medical Imaging

"""
의료 μ˜μƒ μ„Έκ·Έλ©˜ν…Œμ΄μ…˜

SAM의 강점:
- Zero-shot으둜 μƒˆλ‘œμš΄ μž₯κΈ°/병변 μ„Έκ·Έλ©˜ν…Œμ΄μ…˜
- μ „λ¬Έκ°€μ˜ 포인트 클릭만으둜 μ •λ°€ 마슀크

MedSAM: 의료 μ˜μƒμ— fine-tuned SAM
"""

# MedSAM usage sketch (illustrative: the package, class, and helper names
# below are placeholders; see the MedSAM repository for the actual interface)
from medsam import MedSAMPredictor

predictor = MedSAMPredictor("medsam_checkpoint.pth")

# Load a CT/MRI volume
medical_image = load_medical_image("ct_scan.nii")

# Segment slice by slice
for slice_idx in range(medical_image.shape[0]):
    slice_img = medical_image[slice_idx]
    predictor.set_image(slice_img)

    # An expert clicks the lesion location (tumor_x, tumor_y)
    mask, _, _ = predictor.predict(
        point_coords=np.array([[tumor_x, tumor_y]]),
        point_labels=np.array([1]),
    )

Summary

SAM core components

Component        Role                     Notes
Image Encoder    Extract image features   MAE ViT-H, 632M params
Prompt Encoder   Encode prompts           Point/Box/Mask support
Mask Decoder     Generate masks           2-layer Transformer, 4M params

Prompt types

  • Point: click location (foreground/background)
  • Box: bounding box
  • Mask: previous mask (refinement)
  • Text: supported via Grounding SAM

Applications

Use case                 Approach
Interactive Annotation   Fast labeling via clicks
Automatic Segmentation   All objects via grid points
Video Tracking           Object tracking with SAM 2
Medical Imaging          Specialized via MedSAM

λ‹€μŒ 단계


References

Papers

  • Kirillov et al. (2023). "Segment Anything"
  • Ravi et al. (2024). "SAM 2: Segment Anything in Images and Videos"
  • Liu et al. (2023). "Grounding DINO"
  • Ma et al. (2023). "Segment Anything in Medical Images" (MedSAM)
