SAM
Segment Anything (SAM) introduced promptable segmentation with point, box, and mask prompts, and was trained on the large SA-1B mask dataset.
Semantic segmentation assigns a class label to every pixel. Students should understand the full pipeline, from image tensors to dense class maps, as well as the loss functions, evaluation metrics, and modern foundation models built around it.
Segment Anything introduced promptable segmentation with point, box, and mask prompts, trained on the large SA-1B mask dataset.
SAM 2 extends promptable segmentation to images and videos with a streaming memory design for object masks across frames.
SEEM supports interactive segmentation from points, boxes, scribbles, masks, text, and referring expressions.
SegGPT frames segmentation as in-context visual prompting, learning to produce masks from example image-mask pairs.
OneFormer uses one transformer architecture for semantic, instance, and panoptic segmentation with task-conditioned training.
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
class SegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs for semantic segmentation.

    Each item pairs the image at ``image_paths[idx]`` with the mask at
    ``mask_paths[idx]``.

    Args:
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths, aligned index-by-index
            with ``image_paths``.
        image_transform: Optional callable applied to the RGB PIL image
            only. The mask is never passed through it, so a geometric
            transform (crop, flip, resize) here would misalign image and
            mask.

    Raises:
        ValueError: If the two path sequences differ in length.
    """

    def __init__(self, image_paths, mask_paths, image_transform=None):
        # Fail fast: a length mismatch would otherwise surface later as an
        # IndexError or as silently unused masks.
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                "image_paths and mask_paths must have the same length, "
                f"got {len(image_paths)} and {len(mask_paths)}"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.image_transform = image_transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A tuple ``(image, mask)``. ``image`` is the output of
            ``image_transform`` if one was given, otherwise an RGB PIL
            image. ``mask`` is a ``torch.long`` tensor built from the
            mask file's pixel values.
        """
        # Context managers close the underlying files deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(self.image_paths[idx]) as image_file:
            image = image_file.convert("RGB")

        # Mask pixels must be integer class ids:
        # 0=background, 1=road, 2=car, and so on.
        with Image.open(self.mask_paths[idx]) as mask_file:
            mask = torch.as_tensor(np.array(mask_file), dtype=torch.long)

        if self.image_transform is not None:
            image = self.image_transform(image)
        return image, mask
import torch.nn as nn
from torchvision.models.segmentation import deeplabv3_resnet50
num_classes = 5  # include background as one class
model = deeplabv3_resnet50(weights=None, num_classes=num_classes)

# Output shape is [batch, num_classes, H, W].
# Each pixel has num_classes logits before softmax/argmax.
# This forward pass only inspects the output, so autograd is disabled:
# otherwise the module-level `sample_logits` would keep the whole
# computation graph (and its saved activations) alive.
# NOTE(review): `images` is not defined in the visible code - presumably a
# batch of image tensors created elsewhere; confirm.
with torch.no_grad():
    sample_logits = model(images)["out"]
import torch
import torch.nn.functional as F
# Training configuration.
epochs = 40
lr = 1e-4  # controls how large each optimizer update is
weight_decay = 1e-4  # discourages overfitting by penalizing large weights
ignore_index = 255  # pixels with this value do not affect the loss

optimizer = torch.optim.AdamW(
    model.parameters(),  # all model parameters are trainable here
    lr=lr,
    weight_decay=weight_decay,
)
criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)

# NOTE(review): `model`, `train_loader` and `device` are defined outside this
# snippet; the model is assumed to already live on `device` - confirm.
for epoch in range(epochs):
    model.train()
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        logits = model(images)["out"]

        # Resize logits only when their spatial size differs from the mask
        # size; when the sizes already match, the resampling pass is
        # redundant work.
        if logits.shape[-2:] != masks.shape[-2:]:
            logits = F.interpolate(
                logits,
                size=masks.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )

        # CrossEntropyLoss expects raw logits and integer masks.
        loss = criterion(logits, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
import torch
def mean_iou(logits, masks, num_classes, ignore_index=255):
    """Compute the mean intersection-over-union across classes.

    Args:
        logits: Per-class scores with the class dimension at index 1
            (e.g. ``[batch, num_classes, H, W]``).
        masks: Integer class ids with the same shape as the per-pixel
            predictions (``logits`` without its class dimension).
        num_classes: Number of classes; ids ``0 .. num_classes - 1`` are
            evaluated.
        ignore_index: Mask value whose pixels are excluded from both the
            intersection and the union.

    Returns:
        A scalar tensor on ``logits.device``: the mean IoU over classes
        that occur in the predictions or the masks among non-ignored
        pixels. Classes absent from both are skipped. If no class
        qualifies (e.g. every pixel is ignored), returns ``0.0``.
    """
    # Convert class logits to class id prediction per pixel.
    preds = torch.argmax(logits, dim=1)
    valid = masks != ignore_index

    ious = []
    for cls in range(num_classes):
        pred_cls = (preds == cls) & valid
        mask_cls = (masks == cls) & valid
        union = (pred_cls | mask_cls).sum()
        # A class that appears in neither prediction nor ground truth has
        # an undefined IoU (0/0) and must not drag the mean down.
        if union > 0:
            intersection = (pred_cls & mask_cls).sum()
            ious.append(intersection.float() / union.float())

    if not ious:
        # Keep the fallback on the same device as the regular result so
        # callers can combine values without device mismatches.
        return torch.tensor(0.0, device=logits.device)
    return torch.stack(ious).mean()