Segment Anything Model (SAM)

Comprehensive guide to using Meta AI's Segment Anything Model for zero-shot image segmentation.

When to use SAM

Use SAM when:

  • Need to segment any object in images without task-specific training
  • Building interactive annotation tools with point/box prompts
  • Generating training data for other vision models
  • Need zero-shot transfer to new image domains
  • Building object detection/segmentation pipelines
  • Processing medical, satellite, or domain-specific images

Key features:

  • Zero-shot segmentation: Works on any image domain without fine-tuning
  • Flexible prompts: Points, bounding boxes, or previous masks
  • Automatic segmentation: Generate all object masks automatically
  • High quality: Trained on 1.1 billion masks from 11 million images
  • Multiple model sizes: ViT-B (fastest), ViT-L, ViT-H (most accurate)
  • ONNX export: Deploy in browsers and edge devices

Use alternatives instead:

  • YOLO/Detectron2: For real-time object detection with classes
  • Mask2Former: For semantic/panoptic segmentation with categories
  • GroundingDINO + SAM: For text-prompted segmentation
  • SAM 2: For video segmentation tasks

Quick start

Installation

# From GitHub
pip install git+https://github.com/facebookresearch/segment-anything.git

# Optional dependencies
pip install opencv-python pycocotools matplotlib

# Or use the HuggingFace transformers implementation
pip install transformers

Download checkpoints

# ViT-H (largest, most accurate) - 2.4 GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

# ViT-L (medium) - 1.2 GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth

# ViT-B (smallest, fastest) - 375 MB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

Basic usage with SamPredictor

import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Load model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")

# Create predictor
predictor = SamPredictor(sam)

# Set image (computes embeddings once)
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# Predict with point prompts
input_point = np.array([[500, 375]])  # (x, y) coordinates
input_label = np.array([1])           # 1 = foreground, 0 = background

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,  # Returns 3 mask options
)

# Select best mask
best_mask = masks[np.argmax(scores)]
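
For a quick visual check, a minimal matplotlib overlay sketch (matplotlib comes from the optional dependencies above):

import matplotlib.pyplot as plt

plt.imshow(image)
# masked_where makes everything outside the mask transparent
plt.imshow(np.ma.masked_where(~best_mask, best_mask), alpha=0.6)
plt.axis("off")
plt.show()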

HuggingFace Transformers

import torch
from PIL import Image
from transformers import SamModel, SamProcessor

# Load model and processor
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model.to("cuda")

# Process image with point prompt
image = Image.open("image.jpg")
input_points = [[[450, 600]]]  # Batch of points

inputs = processor(image, input_points=input_points, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Generate masks
with torch.no_grad():
    outputs = model(**inputs)

# Post-process masks to original size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
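
The forward pass also returns predicted mask quality in outputs.iou_scores; a minimal sketch for picking the best of the three candidate masks (indexing assumes one image with one point prompt):

scores = outputs.iou_scores.cpu()        # shape: (batch, num_prompts, 3)
best_idx = scores[0, 0].argmax().item()  # best of the 3 candidates
best_mask = masks[0][0, best_idx]        # boolean tensor at original resolution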

Core concepts

Model architecture

SAM Architecture:

┌────────────────┐     ┌────────────────┐     ┌────────────────┐
│ Image Encoder  │────▶│ Prompt Encoder │────▶│  Mask Decoder  │
│     (ViT)      │     │ (Points/Boxes) │     │ (Transformer)  │
└────────────────┘     └────────────────┘     └────────────────┘
        │                      │                      │
 Image embeddings      Prompt embeddings        Masks + IoU
 (computed once)          (per prompt)          predictions
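
This split is why prompting is cheap: set_image runs the ViT encoder once and caches the result, and every later predict call reuses it. A short sketch (get_image_embedding is part of SamPredictor's public API):

predictor.set_image(image)                   # heavy: runs the ViT image encoder once
embedding = predictor.get_image_embedding()  # cached embedding, shape (1, 256, 64, 64)

# cheap: each prompt only runs the prompt encoder + mask decoder
masks, scores, _ = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
)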

Model variants

| Model | Registry key | Checkpoint size | Speed | Accuracy |
|-------|--------------|-----------------|-------|----------|
| ViT-H | vit_h | 2.4 GB | Slowest | Best |
| ViT-L | vit_l | 1.2 GB | Medium | Good |
| ViT-B | vit_b | 375 MB | Fastest | Good |
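
A small helper for picking a variant by VRAM budget; the filename mapping mirrors the download commands above (the helper itself is ours, not part of the library):

from segment_anything import sam_model_registry

# registry key -> downloaded checkpoint file (our own convenience mapping)
CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}

def load_sam(model_type="vit_b", device="cuda"):
    sam = sam_model_registry[model_type](checkpoint=CHECKPOINTS[model_type])
    return sam.to(device=device)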

Prompt types

| Prompt | Description | Use Case |
|--------|-------------|----------|
| Point (foreground) | Click on object | Single object selection |
| Point (background) | Click outside object | Exclude regions |
| Bounding box | Rectangle around object | Larger objects |
| Previous mask | Low-res mask input | Iterative refinement |

Interactive segmentation

Point prompts

# Single foreground point
input_point = np.array([[500, 375]])
input_label = np.array([1])

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

# Multiple points (foreground + background)
input_points = np.array([[500, 375], [600, 400], [450, 300]])
input_labels = np.array([1, 1, 0])  # 2 foreground, 1 background

masks, scores, logits = predictor.predict(
    point_coords=input_points,
    point_labels=input_labels,
    multimask_output=False,  # Single mask when prompts are unambiguous
)

Box prompts

# Bounding box [x1, y1, x2, y2]
input_box = np.array([425, 600, 700, 875])

masks, scores, logits = predictor.predict(
    box=input_box,
    multimask_output=False,
)

Combined prompts

# Box + points for precise control
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    box=np.array([400, 300, 700, 600]),
    multimask_output=False,
)

Iterative refinement

# Initial prediction
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)

# Refine with an additional point, feeding back the previous best mask
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375], [550, 400]]),
    point_labels=np.array([1, 0]),  # Add a background point
    mask_input=logits[np.argmax(scores)][None, :, :],  # Low-res logits of best mask
    multimask_output=False,
)

Automatic mask generation

Basic automatic segmentation

from segment_anything import SamAutomaticMaskGenerator

# Create generator
mask_generator = SamAutomaticMaskGenerator(sam)

# Generate all masks
masks = mask_generator.generate(image)

# Each mask is a dict containing:
# - segmentation: binary mask
# - bbox: [x, y, w, h]
# - area: pixel count
# - predicted_iou: quality score
# - stability_score: robustness score
# - point_coords: generating point

Customized generation

mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,            # Grid density (more = more masks)
    pred_iou_thresh=0.88,          # Quality threshold
    stability_score_thresh=0.95,   # Stability threshold
    crop_n_layers=1,               # Multi-scale crops
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,      # Remove tiny masks
)

masks = mask_generator.generate(image)

Filtering masks

# Sort by area (largest first)
masks = sorted(masks, key=lambda x: x["area"], reverse=True)

# Filter by predicted IoU
high_quality = [m for m in masks if m["predicted_iou"] > 0.9]

# Filter by stability score
stable_masks = [m for m in masks if m["stability_score"] > 0.95]
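
To eyeball the output, a minimal sketch that paints each mask a random translucent color (our own helper, loosely following the official example notebook):

import matplotlib.pyplot as plt
import numpy as np

def show_masks(image, masks):
    """Overlay every generated mask on the image in a random color."""
    overlay = image.astype(float)
    for m in masks:
        color = np.random.randint(0, 256, 3)
        seg = m["segmentation"]
        overlay[seg] = 0.5 * overlay[seg] + 0.5 * color
    plt.imshow(overlay.astype(np.uint8))
    plt.axis("off")
    plt.show()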

Batched inference

Multiple images

# Process multiple images (embeddings are recomputed per image)
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]

all_masks = []
for image in images:
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # set_image expects RGB
    predictor.set_image(image)
    masks, _, _ = predictor.predict(
        point_coords=np.array([[500, 375]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )
    all_masks.append(masks)

Multiple prompts per image

# Process multiple prompts efficiently (one image encoding)
predictor.set_image(image)

# Batch of point prompts
points = [
    np.array([[100, 100]]),
    np.array([[200, 200]]),
    np.array([[300, 300]]),
]

all_masks = []
for point in points:
    masks, scores, _ = predictor.predict(
        point_coords=point,
        point_labels=np.array([1]),
        multimask_output=True,
    )
    all_masks.append(masks[np.argmax(scores)])
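
The loop above still runs the mask decoder once per call; predict_torch can decode a whole batch of prompts in a single pass. A sketch with batched box prompts (boxes must first go through the predictor's coordinate transform):

import torch

boxes = torch.tensor([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
], device=sam.device)

# Map boxes into the model's input coordinate frame
transformed = predictor.transform.apply_boxes_torch(boxes, image.shape[:2])

masks, scores, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed,
    multimask_output=False,
)  # masks: (num_boxes, 1, H, W) boolean tensor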

ONNX deployment

Export model

python scripts/export_onnx_model.py \
    --checkpoint sam_vit_h_4b8939.pth \
    --model-type vit_h \
    --output sam_onnx.onnx \
    --return-single-mask

Use ONNX model

import numpy as np
import onnxruntime

# Load ONNX model
ort_session = onnxruntime.InferenceSession("sam_onnx.onnx")

# Run inference (image embeddings computed separately)
masks = ort_session.run(
    None,
    {
        "image_embeddings": image_embeddings,
        "point_coords": point_coords,
        "point_labels": point_labels,
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.array([0], dtype=np.float32),
        "orig_im_size": np.array([h, w], dtype=np.float32),
    },
)
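
The exported graph contains only the prompt encoder and mask decoder, so the image embedding still comes from the PyTorch encoder, and prompt coordinates need the same resize transform plus a padding point. A sketch of preparing the inputs used above, following the pattern in the official example notebook:

# Embedding from the PyTorch encoder (computed once per image)
predictor.set_image(image)
image_embeddings = predictor.get_image_embedding().cpu().numpy()

# Pad the point prompt with a dummy point (label -1), then transform coords
input_point = np.array([[500, 375]], dtype=np.float32)
input_label = np.array([1], dtype=np.float32)
point_coords = np.concatenate([input_point, np.zeros((1, 2), dtype=np.float32)])[None]
point_labels = np.concatenate([input_label, np.array([-1], dtype=np.float32)])[None]
point_coords = predictor.transform.apply_coords(point_coords, image.shape[:2]).astype(np.float32)

h, w = image.shape[:2]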

Common workflows

Workflow 1: Annotation tool

import cv2
import numpy as np

# Load model
predictor = SamPredictor(sam)
predictor.set_image(image)

def on_click(event, x, y, flags, param):
    if event == cv2.EVENT_LBUTTONDOWN:
        # Foreground point at the clicked pixel
        masks, scores, _ = predictor.predict(
            point_coords=np.array([[x, y]]),
            point_labels=np.array([1]),
            multimask_output=True,
        )
        # Display best mask
        display_mask(masks[np.argmax(scores)])
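
display_mask above is left undefined; one possible implementation with OpenCV (our own sketch, assuming image is the RGB array passed to set_image):

def display_mask(mask):
    """Tint the masked region green and show it in an OpenCV window."""
    overlay = image.copy()
    overlay[mask] = (0.4 * overlay[mask] + 0.6 * np.array([0, 255, 0])).astype(np.uint8)
    cv2.imshow("mask", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))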

Workflow 2: Object extraction

def extract_object(image, point):
    """Extract object at point with transparent background."""
    predictor.set_image(image)

    masks, scores, _ = predictor.predict(
        point_coords=np.array([point]),
        point_labels=np.array([1]),
        multimask_output=True,
    )

    best_mask = masks[np.argmax(scores)]

    # Create RGBA output: alpha channel from the mask
    rgba = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
    rgba[:, :, :3] = image
    rgba[:, :, 3] = best_mask * 255

    return rgba
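
Usage, assuming image is an RGB array (OpenCV expects BGRA on write, hence the conversion):

cutout = extract_object(image, (500, 375))
cv2.imwrite("object.png", cv2.cvtColor(cutout, cv2.COLOR_RGBA2BGRA))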

Workflow 3: Medical image segmentation

# Process medical images (grayscale to RGB)
medical_image = cv2.imread("scan.png", cv2.IMREAD_GRAYSCALE)
rgb_image = cv2.cvtColor(medical_image, cv2.COLOR_GRAY2RGB)

predictor.set_image(rgb_image)

# Segment region of interest
masks, scores, _ = predictor.predict(
    box=np.array([x1, y1, x2, y2]),  # ROI bounding box
    multimask_output=True,
)

Output format

Mask data structure

# SamAutomaticMaskGenerator output (one dict per mask)
{
    "segmentation": np.ndarray,   # H×W binary mask
    "bbox": [x, y, w, h],         # Bounding box
    "area": int,                  # Pixel count
    "predicted_iou": float,       # 0-1 quality score
    "stability_score": float,     # 0-1 robustness score
    "crop_box": [x, y, w, h],     # Generation crop region
    "point_coords": [[x, y]],     # Input point
}

COCO RLE format

from pycocotools import mask as mask_utils

# Encode mask to RLE
rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
rle["counts"] = rle["counts"].decode("utf-8")

# Decode RLE to mask
decoded_mask = mask_utils.decode(rle)
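
RLE keeps annotation files small; a sketch that serializes automatic-generator output to COCO-style JSON records (the record fields here are our own choice):

import json

records = []
for m in masks:  # output of mask_generator.generate(image)
    rle = mask_utils.encode(np.asfortranarray(m["segmentation"].astype(np.uint8)))
    rle["counts"] = rle["counts"].decode("utf-8")  # bytes -> str for JSON
    records.append({"bbox": m["bbox"], "area": m["area"], "segmentation": rle})

with open("masks.json", "w") as f:
    json.dump(records, f)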

Performance optimization

GPU memory

# Use a smaller model for limited VRAM
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

# Process images in batches and clear the CUDA cache between large batches
torch.cuda.empty_cache()

Speed optimization

# Use half precision
sam = sam.half()

# Reduce points for automatic generation (default is 32)
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=16,
)

# For deployment, export to ONNX with --return-single-mask for faster inference


Common issues

| Issue | Solution |
|-------|----------|
| Out of memory | Use ViT-B model, reduce image size |
| Slow inference | Use ViT-B, reduce points_per_side |
| Poor mask quality | Try different prompts, use box + points |
| Edge artifacts | Use stability_score filtering |
| Small objects missed | Increase points_per_side |

References

Resources

  • GitHub: https://github.com/facebookresearch/segment-anything
  • Paper: https://arxiv.org/abs/2304.02643
  • Demo: https://segment-anything.com
  • SAM 2 (Video): https://github.com/facebookresearch/segment-anything-2
  • HuggingFace: https://huggingface.co/facebook/sam-vit-huge
