Segment Anything Model (SAM)
Learning Objectives
- Understand SAM's "Promptable Segmentation" paradigm
- Understand the structure of the Image Encoder, Prompt Encoder, and Mask Decoder
- Understand SAM's training data and methodology
- Learn practical SAM usage
1. SAM Overview
1.1 Foundation Model for Segmentation
SAM (Segment Anything Model) is a promptable vision foundation model released by Meta AI in 2023. Given a prompt such as a click or a box, it segments the indicated object, without task-specific retraining.
┌─────────────────────────────────────────────────────────────────┐
│ SAM's Innovation │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Traditional Segmentation: │
│ • Only specific classes (people, cars, etc.) │
│ • Only object categories seen during training │
│ • Fixed label set per model │
│ │
│ SAM: │
│ • Can segment any object │
│ • Specify desired object with prompts │
│ • Zero-shot: generalizes to unseen objects and image domains │
│ │
│ Prompt Types: │
│ ┌────────────────────────────────────────────────────┐ │
│ │ Point │ Click location (foreground/background) │ │
│ │ Box │ Bounding box │ │
│ │ Mask │ Rough mask (for refinement) │ │
│ │ Text │ Text description (via Grounding SAM; not in released SAM/SAM 2) │ │
│ └────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
1.2 SA-1B Dataset
┌─────────────────────────────────────────────────────────────────┐
│ SA-1B Dataset │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Scale: │
│ • 11M images │
│ • 1.1B (1.1 billion) masks │
│ • Average ~100 masks per image │
│ │
│ Collection Method (Data Engine): │
│ │
│ Phase 1: Assisted Manual (4.3M masks) │
│ ─────────────────────────────────── │
│ • Professional annotators label with SAM assistance │
│ • SAM proposes → humans correct │
│ │
│ Phase 2: Semi-Automatic (5.9M masks) │
│ ─────────────────────────────────── │
│ • SAM auto-generates confident masks │
│ • Annotators label the remaining unannotated objects │
│ │
│ Phase 3: Fully Automatic (1.1B masks) │
│ ─────────────────────────────────── │
│ • Auto-generate with 32×32 grid points │
│ • Keep confident, stable masks; remove duplicates with NMS │
│ │
└─────────────────────────────────────────────────────────────────┘
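In Phase 3, the model is prompted with a regular 32×32 grid of points, and each point acts as a single foreground click. The NumPy sketch below builds such a grid, placing the points at cell centers. The function name `point_grid` and the scaling to pixel coordinates are our own illustration, not the `segment_anything` API.

```python
import numpy as np


def point_grid(n_per_side: int, width: int, height: int) -> np.ndarray:
    """Return an (n_per_side**2, 2) array of (x, y) pixel coordinates at grid-cell centers."""
    offset = 1.0 / (2 * n_per_side)
    # Cell centers along one axis, normalized to [0, 1]
    one_side = np.linspace(offset, 1.0 - offset, n_per_side)
    xs, ys = np.meshgrid(one_side, one_side)  # both (n, n); x varies along columns
    points = np.stack([xs, ys], axis=-1).reshape(-1, 2)
    # Scale normalized coordinates to pixels
    return points * np.array([width, height])


grid = point_grid(32, width=1024, height=768)
print(grid.shape)  # (1024, 2)
```

Each of the 1,024 points yields candidate masks. The filtering step then discards low-confidence, unstable, and duplicate candidates.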
2. SAM Architecture
2.1 Overall Structure
┌─────────────────────────────────────────────────────────────────┐
│ SAM Architecture │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Input Image │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Image Encoder │ │
│ │ (MAE pre-trained ViT-H/16) │ │
│ │ │ │
│ │ • 1024×1024 input → 64×64 feature map │ │
│ │ • 632M parameters │ │
│ │ • Run once per image (expensive) │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ Image Embeddings │
│ (64×64×256) │
│ │ │
│ ┌───────────────┴───────────────┐ │
│ │ │ │
│ ▼ ▼ │
│ ┌───────────────────┐ ┌───────────────────┐ │
│ │ Prompt Encoder │ │ Prompt Encoder │ │
│ │ (Points/Boxes) │ │ (Dense: Mask) │ │
│ │ │ │ │ │
│ │ Sparse Embed │ │ Conv downscale │ │
│ │ (N×256) │ │ (256×64×64) │ │
│ └─────────┬─────────┘ └─────────┬─────────┘ │
│ │ │ │
│ └───────────────┬───────────────┘ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Mask Decoder │ │
│ │ (Lightweight Transformer) │ │
│ │ │ │
│ │ • 2-layer Transformer decoder │ │
│ │ • Cross-attention: prompt ↔ image │ │
│ │ • Self-attention: prompt tokens │ │
│ │ • 4M parameters (very lightweight) │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────┴─────────────┐ │
│ ▼ ▼ │
│ 3 Mask Outputs IoU Scores │
│ (256×256, upscaled) (confidence) │
│ │
└─────────────────────────────────────────────────────────────────┘
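The encoder always consumes a 1024×1024 input, so images of other sizes must be adapted first. In the reference implementation, SAM resizes the longest side to 1024 pixels, zero-pads the bottom and right edges, and scales prompt coordinates by the same factor. The plain-Python sketch below reproduces that arithmetic under those assumptions. `preprocess_shape`, `padding`, and `transform_coords` are hypothetical helper names.

```python
def preprocess_shape(h: int, w: int, long_side: int = 1024) -> tuple[int, int]:
    """(h, w) after resizing so the longer side equals long_side (aspect ratio kept)."""
    scale = long_side / max(h, w)
    return int(h * scale + 0.5), int(w * scale + 0.5)


def padding(h: int, w: int, long_side: int = 1024) -> tuple[int, int]:
    """Zero padding (bottom, right) needed to reach a long_side × long_side input."""
    new_h, new_w = preprocess_shape(h, w, long_side)
    return long_side - new_h, long_side - new_w


def transform_coords(coords, orig_hw, long_side: int = 1024):
    """Scale (x, y) prompt coordinates from the original image to the resized image."""
    h, w = orig_hw
    new_h, new_w = preprocess_shape(h, w, long_side)
    return [(x * new_w / w, y * new_h / h) for x, y in coords]


print(preprocess_shape(480, 640))                  # (768, 1024)
print(padding(480, 640))                           # (256, 0)
print(transform_coords([(500, 375)], (480, 640)))  # [(800.0, 600.0)]
```

For a 640×480 photo, the resized image is 1024×768 and needs 256 rows of padding. A click at (500, 375) moves to (800, 600) in encoder coordinates.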
2.2 Image Encoder
"""
SAM Image Encoder: MAE pre-trained ViT-H
Features:
- ViT-H/16: 632M parameters
- Input: 1024×1024 (high resolution)
- Output: 64×64×256 feature map
- Positional Embedding: Windowed + Global attention
Why MAE pre-training?
- Masking-based learning suits dense prediction
- Self-supervised learning utilizes large-scale data
- Effective for patch-level representation learning
"""
import torch
import torch.nn as nn
class SAMImageEncoder(nn.Module):
"""
SAM's Image Encoder (simplified version)
Actually uses ViT-H, but this is
simplified for understanding the structure
"""
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
embed_dim: int = 1280, # ViT-H
depth: int = 32,
num_heads: int = 16,
out_chans: int = 256,
):
super().__init__()
self.patch_embed = nn.Conv2d(3, embed_dim, patch_size, patch_size)
self.pos_embed = nn.Parameter(
torch.zeros(1, (img_size // patch_size) ** 2, embed_dim)
)
# Transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads)
for _ in range(depth)
])
self.neck = nn.Sequential(
nn.Conv2d(embed_dim, out_chans, kernel_size=1),
nn.LayerNorm(out_chans),
nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1),
nn.LayerNorm(out_chans),
)
def forward(self, x):
# x: (B, 3, 1024, 1024)
x = self.patch_embed(x) # (B, embed_dim, 64, 64)
x = x.flatten(2).transpose(1, 2) # (B, 4096, embed_dim)
x = x + self.pos_embed
for block in self.blocks:
x = block(x)
# Reshape back to 2D
B, N, C = x.shape
H = W = int(N ** 0.5)
x = x.transpose(1, 2).reshape(B, C, H, W)
x = self.neck(x) # (B, 256, 64, 64)
return x
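The simplified encoder uses global attention in every block. The real ViT-H instead restricts most blocks to 14×14 windows, with only a few global blocks, to keep attention over 4096 tokens affordable. Below is a minimal PyTorch sketch of ViTDet-style window partitioning. It assumes a (B, H, W, C) token layout, and the function names are ours.

```python
import torch
import torch.nn.functional as F


def window_partition(x: torch.Tensor, window_size: int):
    """Split a (B, H, W, C) token grid into (B * num_windows, ws, ws, C) windows."""
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    # F.pad pads from the last dim backwards: (C, W, H) pairs
    x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(windows, window_size, padded_hw, hw):
    """Inverse of window_partition: merge the windows and drop the padding."""
    Hp, Wp = padded_hw
    H, W = hw
    B = windows.shape[0] // ((Hp // window_size) * (Wp // window_size))
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
    return x[:, :H, :W, :]


tokens = torch.randn(1, 64, 64, 8)  # small channel dim for illustration
windows, padded = window_partition(tokens, 14)
print(windows.shape)  # torch.Size([25, 14, 14, 8])
restored = window_unpartition(windows, 14, padded, (64, 64))
print(torch.equal(restored, tokens))  # True
# Attention pairs: 25 windows × 196² vs. 4096² for global attention
print(25 * 196**2, 4096**2)  # 960400 16777216
```

Padding 64 up to 70 gives 5×5 = 25 windows. That cuts the number of attention pairs per block by roughly 17×.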
2.3 Prompt Encoder
import math
import torch
import torch.nn as nn


class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dim of a (B, C, H, W) tensor."""

    def forward(self, x):
        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class PositionalEncoding(nn.Module):
    """Random Fourier-feature encoding of (x, y) pixel coordinates (SAM uses this kind of encoding)."""

    def __init__(self, embed_dim: int, image_size: int):
        super().__init__()
        self.image_size = image_size
        # Fixed random projection 2 → embed_dim // 2 (not trained)
        self.register_buffer("gaussian", torch.randn(2, embed_dim // 2))

    def forward(self, coords):
        # coords: (..., 2) in pixels → normalize to [-1, 1]
        coords = coords / self.image_size * 2 - 1
        proj = 2 * math.pi * coords @ self.gaussian
        return torch.cat([proj.sin(), proj.cos()], dim=-1)  # (..., embed_dim)


class SAMPromptEncoder(nn.Module):
    """
    SAM Prompt Encoder (simplified)
    Prompt types:
    1. Points: (x, y) + label (1 = foreground, 0 = background)
    2. Boxes: (x1, y1, x2, y2)
    3. Masks: previous low-res mask logits (for refinement)
    """

    def __init__(self, embed_dim: int = 256, image_size: int = 1024, grid_size: int = 64):
        super().__init__()
        self.embed_dim = embed_dim
        self.grid_size = grid_size  # image embedding is grid_size × grid_size
        # Point type embeddings, indexed by label (0 = background, 1 = foreground)
        self.point_embeddings = nn.Embedding(2, embed_dim)
        # Positional encoding for point and box coordinates
        self.pe_layer = PositionalEncoding(embed_dim, image_size)
        # Box corner embeddings
        self.box_embeddings = nn.Embedding(2, embed_dim)  # top-left, bottom-right
        # Mask encoder (dense prompt): (B, 1, 256, 256) → (B, embed_dim, 64, 64)
        self.mask_downscaler = nn.Sequential(
            nn.Conv2d(1, embed_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(embed_dim // 4),
            nn.GELU(),
            nn.Conv2d(embed_dim // 4, embed_dim, kernel_size=2, stride=2),
            LayerNorm2d(embed_dim),
            nn.GELU(),
            nn.Conv2d(embed_dim, embed_dim, kernel_size=1),
        )
        # Learned embedding used when no mask prompt is given
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def forward(self, points=None, boxes=None, masks=None):
        """
        Args:
            points: tuple of coords (B, N, 2) in pixels and labels (B, N)
            boxes: (B, 4) bounding boxes in pixels
            masks: (B, 1, 256, 256) previous mask logits
        Returns:
            sparse_embeddings: (B, N_prompts, embed_dim), or None without points/boxes
            dense_embeddings: (B, embed_dim, 64, 64)
        """
        sparse_embeddings = []
        # Point prompts: positional encoding + learned type embedding
        if points is not None:
            coords, labels = points
            point_embed = self.pe_layer(coords.float())  # (B, N, D)
            sparse_embeddings.append(point_embed + self.point_embeddings(labels.long()))
        # Box prompts: encoded as two corner points
        if boxes is not None:
            corners = boxes.reshape(-1, 2, 2).float()  # (B, 2, 2)
            sparse_embeddings.append(self.pe_layer(corners) + self.box_embeddings.weight)
        sparse_embeddings = torch.cat(sparse_embeddings, dim=1) if sparse_embeddings else None
        # Dense prompt (mask)
        if masks is not None:
            dense_embeddings = self.mask_downscaler(masks)
        else:
            # No mask: broadcast the learned embedding over the whole grid
            batch = sparse_embeddings.shape[0] if sparse_embeddings is not None else 1
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                batch, -1, self.grid_size, self.grid_size
            )
        return sparse_embeddings, dense_embeddings
2.4 Mask Decoder
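import torch
import torch.nn as nn


class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dim of a (B, C, H, W) tensor."""

    def forward(self, x):
        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class SAMMaskDecoder(nn.Module):
    """
    SAM Mask Decoder (simplified)
    Structure:
    - 2-layer two-way Transformer
    - Self-attention: tokens
    - Cross-attention: tokens ↔ image (both directions)
    - 4 mask tokens: 1 single-mask output + 3 multimask outputs
    - 1 IoU token → IoU prediction head (one score per mask token)
    """

    def __init__(self, embed_dim: int = 256, num_heads: int = 8, num_multimask_outputs: int = 3):
        super().__init__()
        self.num_mask_tokens = num_multimask_outputs + 1  # 4
        # Learnable output tokens
        self.iou_token = nn.Embedding(1, embed_dim)
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, embed_dim)
        # Transformer layers
        self.transformer = TwoWayTransformer(depth=2, embed_dim=embed_dim, num_heads=num_heads)
        # Upscale image features 4×: (256, 64, 64) → (32, 256, 256)
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(embed_dim // 4),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dim // 4, embed_dim // 8, kernel_size=2, stride=2),
            nn.GELU(),
        )
        # Hypernetworks: one MLP per mask token turns the token into 32-d dynamic weights
        self.hypernetworks = nn.ModuleList([
            nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, embed_dim // 8))
            for _ in range(self.num_mask_tokens)
        ])
        # IoU head: predicts a quality score for each mask token
        self.iou_prediction_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, self.num_mask_tokens)
        )

    def forward(self, image_embeddings, image_pe, sparse_embeddings, dense_embeddings,
                multimask_output=True):
        """
        Args:
            image_embeddings: (B, 256, 64, 64)
            image_pe: (1, 256, 64, 64) positional encoding of the image grid
            sparse_embeddings: (B, N_prompts, 256)
            dense_embeddings: (B, 256, 64, 64)
        Returns:
            masks: (B, 3, 256, 256) low-res mask logits, or (B, 1, 256, 256) if not multimask
            iou_predictions: (B, 3), or (B, 1) if not multimask
        """
        B, C, H, W = image_embeddings.shape
        # Tokens = [IoU token, 4 mask tokens, prompt tokens]
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        tokens = torch.cat([output_tokens.unsqueeze(0).expand(B, -1, -1), sparse_embeddings], dim=1)
        # Dense prompt embeddings are added to the image embedding
        src = (image_embeddings + dense_embeddings).flatten(2).transpose(1, 2)  # (B, HW, C)
        pos = image_pe.flatten(2).transpose(1, 2).expand(B, -1, -1)  # (B, HW, C)
        # Two-way attention between tokens and image
        tokens, src = self.transformer(tokens, src, pos)
        iou_token_out = tokens[:, 0]  # (B, C)
        mask_tokens_out = tokens[:, 1:1 + self.num_mask_tokens]  # (B, 4, C)
        # Upscale image features, then take a dot product with per-token dynamic weights
        src = src.transpose(1, 2).reshape(B, C, H, W)
        upscaled = self.output_upscaling(src)  # (B, 32, 256, 256)
        hyper_in = torch.stack(
            [mlp(mask_tokens_out[:, i]) for i, mlp in enumerate(self.hypernetworks)], dim=1
        )  # (B, 4, 32)
        masks = (hyper_in @ upscaled.flatten(2)).reshape(B, -1, *upscaled.shape[-2:])  # (B, 4, 256, 256)
        iou_predictions = self.iou_prediction_head(iou_token_out)  # (B, 4)
        # Multimask mode returns the 3 multimask outputs; otherwise the single-mask output
        sl = slice(1, None) if multimask_output else slice(0, 1)
        return masks[:, sl], iou_predictions[:, sl]


class TwoWayAttentionBlock(nn.Module):
    """Token self-attention → token-to-image attention → token MLP → image-to-token attention."""

    def __init__(self, embed_dim, num_heads, mlp_dim=2048):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.token_to_image = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, embed_dim))
        self.image_to_token = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(4)])

    def forward(self, tokens, image, image_pe):
        tokens = self.norms[0](tokens + self.self_attn(tokens, tokens, tokens)[0])
        # Image keys carry the positional encoding
        tokens = self.norms[1](tokens + self.token_to_image(tokens, image + image_pe, image)[0])
        tokens = self.norms[2](tokens + self.mlp(tokens))
        image = self.norms[3](image + self.image_to_token(image + image_pe, tokens, tokens)[0])
        return tokens, image


class TwoWayTransformer(nn.Module):
    """
    Two-way Transformer for SAM (simplified; the real one also applies a
    final token-to-image attention after the last block)
    """

    def __init__(self, depth, embed_dim, num_heads):
        super().__init__()
        self.layers = nn.ModuleList([
            TwoWayAttentionBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])

    def forward(self, tokens, image, image_pe):
        for layer in self.layers:
            tokens, image = layer(tokens, image, image_pe)
        return tokens, image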
3. Using SAM
3.1 Basic Usage
from segment_anything import SamPredictor, sam_model_registry
import cv2
import numpy as np
# Load model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")
predictor = SamPredictor(sam)
# Set image
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)
# Segment with point prompt
input_point = np.array([[500, 375]]) # click location
input_label = np.array([1]) # 1: foreground, 0: background
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,  # output 3 masks
)
# Select mask with highest score
best_mask = masks[np.argmax(scores)]
3.2 Various Prompts
# 1. Multiple points
input_points = np.array([[500, 375], [600, 400], [450, 350]])
input_labels = np.array([1, 1, 0])  # 2 foreground, 1 background
masks, scores, _ = predictor.predict(
    point_coords=input_points,
    point_labels=input_labels,
    multimask_output=False,  # single mask
)
# 2. Box prompt
input_box = np.array([100, 100, 500, 400])  # x1, y1, x2, y2
masks, scores, _ = predictor.predict(
    box=input_box,
    multimask_output=False,
)
# 3. Point + Box combined
masks, scores, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)
# 4. Iterative refinement: feed back the low-res logits (1×256×256) of the best mask
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)
prev_logits = logits[np.argmax(scores)][None, :, :]
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=prev_logits,
    multimask_output=False,
)
3.3 Automatic Mask Generation
from segment_anything import SamAutomaticMaskGenerator
# Automatic mask generator
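mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,  # 32×32 grid of point prompts
    pred_iou_thresh=0.88,  # keep masks whose predicted IoU exceeds this
    stability_score_thresh=0.95,  # keep masks that stay stable under threshold shifts
    min_mask_region_area=100,  # remove small disconnected regions/holes (requires OpenCV)
)
# Generate all masks in the image (H×W×3 uint8 RGB)
masks = mask_generator.generate(image)
# Result: list of dicts
# {
#     'segmentation': binary mask (H, W),
#     'area': mask area in pixels,
#     'bbox': bounding box in XYWH format,
#     'predicted_iou': IoU score predicted by the model,
#     'point_coords': input point that generated this mask,
#     'stability_score': stability score,
#     'crop_box': crop used for generation (XYWH),
# }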
print(f"Found {len(masks)} masks")
# Visualization
import matplotlib.pyplot as plt
def show_masks(image, masks):
    plt.figure(figsize=(15, 10))
    plt.imshow(image)
    for mask in masks:
        m = mask['segmentation']
        color = np.random.random(3)
        colored_mask = np.zeros((*m.shape, 4))
        colored_mask[m] = [*color, 0.5]  # RGBA: random color at 50% opacity
        plt.imshow(colored_mask)
    plt.axis('off')
    plt.show()


show_masks(image, masks)
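The `stability_score_thresh` filter keeps masks whose shape barely changes when the cutoff applied to the logits moves. The SAM paper defines stability as the IoU between the masks obtained at a slightly higher threshold and at a slightly lower one. The NumPy sketch below implements that definition. It assumes a cutoff of 0.0 on the logits and an offset of 1.0, and the function name is ours.

```python
import numpy as np


def stability_score(logits: np.ndarray, mask_threshold: float = 0.0, offset: float = 1.0) -> float:
    """IoU between the masks thresholded at (mask_threshold + offset) and (mask_threshold - offset)."""
    # The high-threshold mask is a subset of the low-threshold mask,
    # so intersection = high-threshold area and union = low-threshold area.
    intersection = np.count_nonzero(logits > mask_threshold + offset)
    union = np.count_nonzero(logits > mask_threshold - offset)
    return intersection / union if union > 0 else 0.0


# Sharp logits: every pixel is far from the 0.0 cutoff → stable
sharp = np.array([[5.0, 5.0, -5.0], [5.0, -5.0, -5.0]])
# Soft logits: many pixels sit near the cutoff → unstable
soft = np.array([[0.5, 2.0, -0.5], [0.5, -2.0, -0.5]])
print(stability_score(sharp))  # 1.0
print(stability_score(soft))   # 0.2
```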
3.4 Using Hugging Face Transformers
from transformers import SamModel, SamProcessor
import torch
from PIL import Image
# Load model
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
# Load image (force 3-channel RGB)
image = Image.open("image.jpg").convert("RGB")
# Point prompt: one image, one (x, y) point
input_points = [[[500, 375]]]
inputs = processor(image, input_points=input_points, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# Post-process: upscale the low-res masks back to the original image size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
scores = outputs.iou_scores  # predicted IoU for each candidate mask
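As with `SamPredictor`, the model proposes three candidate masks per prompt. Suppose the post-processed masks arrive as one `(point_batch, 3, H, W)` tensor per image and `iou_scores` has shape `(batch, point_batch, 3)`. Under those assumed shapes, the best candidate can be picked as shown below. `pick_best` is a hypothetical helper, and it is checked here on toy tensors rather than real model output.

```python
import torch


def pick_best(post_processed_masks, iou_scores, image_idx=0, prompt_idx=0):
    """Return (mask, score) for the highest-scoring candidate of one image and one prompt.

    Assumes post_processed_masks[image_idx] has shape (point_batch, 3, H, W)
    and iou_scores has shape (batch, point_batch, 3).
    """
    scores = iou_scores[image_idx, prompt_idx]  # (3,)
    best = int(scores.argmax())
    return post_processed_masks[image_idx][prompt_idx, best], float(scores[best])


# Toy tensors with the assumed shapes: 1 image, 1 prompt, 3 candidates of 4×4
fake_masks = [torch.zeros(1, 3, 4, 4, dtype=torch.bool)]
fake_masks[0][0, 2, :2, :2] = True
fake_scores = torch.tensor([[[0.71, 0.80, 0.93]]])
mask, score = pick_best(fake_masks, fake_scores)
print(int(mask.sum()), round(score, 2))  # 4 0.93

# With the outputs of the snippet above: best_mask, best_score = pick_best(masks, scores)
```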
4. SAM 2 (2024)
4.1 SAM 2 Improvements
┌─────────────────────────────────────────────────────────────────┐
│ SAM vs SAM 2 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ SAM (2023): │
│ • Images only │
│ • Independent frame processing │
│ • Video: each frame must be prompted separately │
│ │
│ SAM 2 (2024): │
│ • One model for images and video │
│ • Temporal consistency via memory attention │
│ • Prompt once → masks propagate through the video │
│ │
│ New Components: │
│ ┌────────────────────────────────────────────────────┐ │
│ │ Memory Encoder │ Fuse predicted mask + frame features │ │
│ │ Memory Bank │ Store memories of recent and prompted frames │ │
│ │ Memory Attention │ Current frame attends to the memory bank │ │
│ └────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
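The memory components can be illustrated with a toy model. Memories from prompted frames are kept separately, recent frames pass through a fixed-size FIFO, and the current frame's tokens cross-attend to everything in the bank. This is a simplified sketch under stated assumptions: the class names, a capacity of 6 recent frames, and a single attention layer. It is not SAM 2's implementation, which stacks several attention layers and also bounds the number of prompted-frame memories.

```python
from collections import deque

import torch
import torch.nn as nn


class ToyMemoryBank:
    """Prompted-frame memories kept separately, plus a FIFO of recent-frame memories."""

    def __init__(self, max_recent: int = 6):
        self.prompted = {}                      # frame_idx -> (N, C) memory tokens
        self.recent = deque(maxlen=max_recent)  # (frame_idx, tokens); oldest evicted first

    def add(self, frame_idx: int, tokens: torch.Tensor, prompted: bool = False):
        if prompted:
            self.prompted[frame_idx] = tokens
        else:
            self.recent.append((frame_idx, tokens))

    def tokens(self) -> torch.Tensor:
        mems = list(self.prompted.values()) + [t for _, t in self.recent]
        return torch.cat(mems, dim=0)  # (total memory tokens, C)


class ToyMemoryAttention(nn.Module):
    """One cross-attention layer: current-frame tokens attend to memory tokens."""

    def __init__(self, dim: int, num_heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, frame_tokens: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        # frame_tokens: (1, HW, C); memory: (M, C) gets a batch dim for attention
        mem = memory.unsqueeze(0)
        out, _ = self.attn(frame_tokens, mem, mem)
        return self.norm(frame_tokens + out)  # residual connection + norm


dim, hw = 32, 16  # toy sizes: 16 tokens per frame, 32 channels
bank = ToyMemoryBank(max_recent=6)
bank.add(0, torch.randn(hw, dim), prompted=True)  # frame the user clicked on
for t in range(1, 10):
    bank.add(t, torch.randn(hw, dim))  # FIFO keeps only frames 4..9
memory = bank.tokens()
print(memory.shape)  # torch.Size([112, 32]): 7 frames × 16 tokens
conditioned = ToyMemoryAttention(dim)(torch.randn(1, hw, dim), memory)
print(conditioned.shape)  # torch.Size([1, 16, 32])
```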
4.2 SAM 2 Video Usage
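import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor
# Arguments: model config, then checkpoint (names from the original SAM 2 release)
predictor = build_sam2_video_predictor(
    "sam2_hiera_l.yaml",
    "sam2_hiera_large.pt",
    device="cuda",
)
# The original SAM 2 release loads a video as a directory of JPEG frames
video_dir = "video_frames/"
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(video_path=video_dir)
    # Prompt on the first frame
    _, object_ids, mask_logits = predictor.add_new_points_or_box(
        state,
        frame_idx=0,
        obj_id=1,
        points=np.array([[500, 375]], dtype=np.float32),
        labels=np.array([1], dtype=np.int32),
    )
    # Propagate automatically to the remaining frames
    for frame_idx, object_ids, mask_logits in predictor.propagate_in_video(state):
        # mask_logits: per-object mask logits for this frame (> 0.0 means foreground)
        print(f"Frame {frame_idx}: {len(object_ids)} objects")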
5. SAM Applications
5.1 Grounding SAM (Text → Segment)
"""
Grounding SAM = Grounding DINO + SAM
1. Grounding DINO: text → bounding box
2. SAM: bounding box → segmentation
Result: Segmentation from text prompts
"""
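import torch
from groundingdino.util.inference import load_model, load_image, predict
from segment_anything import SamPredictor, sam_model_registry
from torchvision.ops import box_convert
# Detect boxes with Grounding DINO (model config, then checkpoint)
grounding_dino = load_model(
    "groundingdino/config/GroundingDINO_SwinB_cfg.py",
    "groundingdino_swinb.pth",
)
# image_source: RGB numpy array (for SAM); image_tensor: normalized tensor (for Grounding DINO)
image_source, image_tensor = load_image("image.jpg")
boxes, logits, phrases = predict(
    model=grounding_dino,
    image=image_tensor,
    caption="a cat",
    box_threshold=0.3,
    text_threshold=0.25,
)
# Grounding DINO returns normalized (cx, cy, w, h); SAM expects pixel (x1, y1, x2, y2)
h, w, _ = image_source.shape
boxes_xyxy = box_convert(boxes * torch.tensor([w, h, w, h]), in_fmt="cxcywh", out_fmt="xyxy")
# Segment with SAM
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
predictor.set_image(image_source)
masks = []
for box in boxes_xyxy:
    mask, _, _ = predictor.predict(box=box.numpy(), multimask_output=False)
    masks.append(mask)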
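The loop leaves one `(1, H, W)` boolean array per detected box. To combine them into a single instance map, with 0 for background and k for the k-th detection, they can be merged as in the NumPy sketch below. `masks_to_label_map` is a hypothetical helper. Letting later masks overwrite earlier ones where they overlap is an arbitrary choice.

```python
import numpy as np


def masks_to_label_map(masks) -> np.ndarray:
    """Merge binary masks of shape (1, H, W) or (H, W) into one (H, W) int32 label map.

    0 = background, k = k-th mask (1-based); where masks overlap, later masks win.
    Returns None for an empty list.
    """
    label_map = None
    for k, m in enumerate(masks, start=1):
        m = np.asarray(m)
        m = m.reshape(m.shape[-2:]).astype(bool)  # drop a leading singleton dim if present
        if label_map is None:
            label_map = np.zeros(m.shape, dtype=np.int32)
        label_map[m] = k
    return label_map


a = np.array([[[True, True], [False, False]]])  # (1, 2, 2), like SamPredictor output
b = np.array([[[False, True], [False, True]]])
print(masks_to_label_map([a, b]).tolist())  # [[1, 2], [0, 2]]
```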
"""
SAM-based Interactive Labeling Tool
1. Load image
2. User clicks points/boxes
3. SAM generates masks in real-time
4. User refines (positive/negative points)
5. Save final mask
"""
import cv2
import numpy as np
from segment_anything import SamPredictor
class SAMAnnotator:
def __init__(self, sam_checkpoint):
self.sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
self.predictor = SamPredictor(self.sam)
self.points = []
self.labels = []
def set_image(self, image):
self.image = image.copy()
self.predictor.set_image(image)
self.points = []
self.labels = []
def add_point(self, x, y, is_foreground=True):
self.points.append([x, y])
self.labels.append(1 if is_foreground else 0)
return self.predict()
def predict(self):
if not self.points:
return None
masks, scores, _ = self.predictor.predict(
point_coords=np.array(self.points),
point_labels=np.array(self.labels),
multimask_output=False,
)
return masks[0]
def reset(self):
self.points = []
self.labels = []
# Usage example (with OpenCV mouse callback)
# annotator = SAMAnnotator("sam_vit_h.pth")
# annotator.set_image(image)
# mask = annotator.add_point(500, 375, is_foreground=True)
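For real-time feedback, the annotator typically redraws the image with the current mask blended on top after each click. Below is a NumPy sketch of that alpha blending. The function name, default color, and alpha are illustrative assumptions.

```python
import numpy as np


def overlay_mask(image: np.ndarray, mask: np.ndarray, color=(30, 144, 255), alpha: float = 0.5) -> np.ndarray:
    """Blend `color` into the pixels of an (H, W, 3) uint8 image where the boolean mask is True."""
    out = image.astype(np.float32)
    out[mask] = (1 - alpha) * out[mask] + alpha * np.array(color, dtype=np.float32)
    return out.round().astype(np.uint8)


image = np.full((2, 2, 3), 100, dtype=np.uint8)
mask = np.array([[True, False], [False, False]])
blended = overlay_mask(image, mask)
print(blended[0, 0].tolist(), blended[1, 1].tolist())  # [65, 122, 178] [100, 100, 100]
```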
5.3 Medical Imaging
"""
Medical Image Segmentation
SAM's strengths:
- Zero-shot segmentation of new organs/lesions
- Precise masks from expert point clicks
MedSAM: SAM fine-tuned on medical images
"""
# MedSAM usage example
from medsam import MedSAMPredictor
predictor = MedSAMPredictor("medsam_checkpoint.pth")
# Load CT/MRI image
medical_image = load_medical_image("ct_scan.nii")
# Slice-by-slice segmentation
for slice_idx in range(medical_image.shape[0]):
slice_img = medical_image[slice_idx]
predictor.set_image(slice_img)
# Expert clicks lesion location
mask, _, _ = predictor.predict(
point_coords=np.array([[tumor_x, tumor_y]]),
point_labels=np.array([1]),
)
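CT slices are stored in Hounsfield units (HU), far outside the 0 to 255 range that SAM-style predictors expect, so they are usually mapped to 8-bit before encoding. One common approach is to clip to an intensity window first. The NumPy sketch below assumes a soft-tissue window (level 40 HU, width 400 HU). The window and the function name are illustrative choices, not MedSAM requirements.

```python
import numpy as np


def window_ct(slice_hu: np.ndarray, level: float = 40.0, width: float = 400.0) -> np.ndarray:
    """Clip HU values to [level - width/2, level + width/2] and scale to an (H, W, 3) uint8 image."""
    lo, hi = level - width / 2, level + width / 2
    clipped = np.clip(slice_hu, lo, hi)
    scaled = ((clipped - lo) / (hi - lo) * 255.0).round().astype(np.uint8)
    # Replicate the gray channel: SAM-style predictors take 3-channel input
    return np.repeat(scaled[..., None], 3, axis=-1)


slice_hu = np.array([[-1000.0, -160.0], [40.0, 1000.0]])  # air, window minimum, soft tissue, bone
rgb = window_ct(slice_hu)
print(rgb.shape)             # (2, 2, 3)
print(rgb[..., 0].tolist())  # [[0, 0], [128, 255]]
```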
Summary
SAM Key Components
| Component | Role | Features |
|---|---|---|
| Image Encoder | Image feature extraction | MAE ViT-H, 632M params |
| Prompt Encoder | Prompt encoding | Point/Box/Mask support |
| Mask Decoder | Mask generation | 2-layer Transformer, ~4M params |
Prompt Types
- Point: Click location (foreground/background)
- Box: Bounding box
- Mask: Previous mask (refinement)
- Text: Supported via Grounding SAM
Applications
| Use Case | Method |
|---|---|
| Interactive Annotation | Fast labeling with clicks |
| Automatic Segmentation | Grid of point prompts covering all objects |
| Video Tracking | Object tracking with SAM 2 |
| Medical Imaging | Specialized with MedSAM |
References
Papers
- Kirillov et al. (2023). "Segment Anything"
- Ravi et al. (2024). "SAM 2: Segment Anything in Images and Videos"
- Liu et al. (2023). "Grounding DINO"
- Ma et al. (2023). "Segment Anything in Medical Images" (MedSAM)