CSE 438 Digital Image Processing
Self-Supervised Learning + Semantic Segmentation
This page explains how an SSL-pretrained encoder becomes the backbone of a dense prediction model. Because the encoder has already learned visual features from unlabeled images, the segmentation model needs fewer labeled masks and usually transfers better when masks are limited.
Unlabeled images → SSL backbone → Segmentation decoder → Pixel mask
How the Combined Logic Works
Stage 1
Pretrain encoder
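An SSL method such as DINO, MAE, JEPA, or BYOL learns visual features from unlabeled images. The training signal comes from matching teacher and student outputs across augmented crops (DINO, BYOL), reconstructing masked patches (MAE), or predicting the representations of masked regions (JEPA).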
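As one concrete instance of a masking objective, MAE hides a large random subset of patch tokens and feeds only the visible ones to the encoder. The sketch below shows that per-sample random masking step on a tensor of patch embeddings. The function name `random_masking` and the 75% default ratio are illustrative assumptions, and the encoder, reconstruction decoder, and loss are omitted.

```python
import torch


def random_masking(tokens: torch.Tensor, mask_ratio: float = 0.75):
    """MAE-style random masking (illustrative sketch).

    tokens: (N, L, D) patch embeddings.
    Returns the visible tokens (N, L_keep, D), a mask (N, L) that is 1 for
    hidden patches in the original patch order, and the kept patch indices.
    """
    n, length, dim = tokens.shape
    len_keep = int(length * (1 - mask_ratio))
    # One random score per patch; the lowest-scoring patches stay visible.
    noise = torch.rand(n, length, device=tokens.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]
    visible = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, dim))
    # Mark every patch as hidden, then clear the positions that were kept.
    mask = torch.ones(n, length, device=tokens.device)
    mask.scatter_(1, ids_keep, 0.0)
    return visible, mask, ids_keep


# 196 patches = a 14 x 14 grid from a 224 x 224 image with 16 x 16 patches.
patches = torch.randn(2, 196, 768)
visible, mask, ids_keep = random_masking(patches)
print(visible.shape)    # torch.Size([2, 49, 768])
print(mask.sum(dim=1))  # tensor([147., 147.])
```

The encoder only ever sees the 49 visible tokens per image, which is what makes this form of pretraining cheap; the mask is what a reconstruction loss would use to score only the hidden patches.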
Stage 2
Add decoder
A segmentation decoder upsamples encoder features and produces class logits for every pixel.
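For a ViT-B/16 encoder, the encoder features are a sequence of patch tokens, so the decoder first has to fold them back onto the spatial patch grid. The sketch below traces the shapes for a 224 × 224 input (one class token plus a 14 × 14 grid of 768-dimensional patch tokens). The random tensor stands in for real encoder output, and the single 1×1 convolution with six classes is an assumed minimal head, not a tuned decoder.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, grid, dim, num_classes = 2, 14, 768, 6

# Stand-in for ViT output: 1 class token + 14*14 patch tokens per image.
tokens = torch.randn(batch, 1 + grid * grid, dim)

# Drop the class token and fold the row-major patch sequence into a 2D grid.
patch_tokens = tokens[:, 1:]                                                 # (2, 196, 768)
feature_map = patch_tokens.transpose(1, 2).reshape(batch, dim, grid, grid)  # (2, 768, 14, 14)

# A 1x1 convolution scores every patch position for each class.
classifier = nn.Conv2d(dim, num_classes, kernel_size=1)
coarse_logits = classifier(feature_map)                                      # (2, 6, 14, 14)

# Bilinear upsampling gives one score vector per input pixel.
logits = F.interpolate(coarse_logits, size=(224, 224), mode="bilinear", align_corners=False)
pred_mask = logits.argmax(dim=1)                                             # (2, 224, 224) class indices
print(feature_map.shape, logits.shape, pred_mask.shape)
```

Each 16 × 16 pixel region shares one coarse prediction before upsampling, which is why small objects benefit from higher resolution or multi-scale features.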
Stage 3
Fine-tune masks
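Supervised training uses the labeled masks and a per-pixel loss such as cross-entropy. The backbone can first be frozen while the decoder trains, then partially or fully unfrozen and updated with a smaller learning rate than the decoder.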
PyTorch Code With Comments
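```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vit_b_16


class SSLViTBackbone(nn.Module):
    def __init__(self, checkpoint_path=None):
        super().__init__()
        # torchvision's ViT-B/16 expects 224x224 inputs by default.
        self.encoder = vit_b_16(weights=None)
        self.encoder.heads = nn.Identity()
        if checkpoint_path is not None:
            # Load weights learned from an SSL method such as MAE or DINO.
            # strict=False skips extra keys (e.g., a projector head), but it also
            # silently skips keys whose names differ from torchvision's ViT
            # (official MAE/DINO checkpoints use timm-style names and need remapping),
            # so always inspect what was actually loaded.
            state = torch.load(checkpoint_path, map_location="cpu")
            missing, unexpected = self.encoder.load_state_dict(state, strict=False)
            print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")

    def forward(self, x):
        # Segmentation needs spatial features, so return the patch tokens
        # reshaped into a feature map instead of the pooled class token.
        # _process_input is a private torchvision helper and may change between versions.
        n, _, h, w = x.shape
        p = self.encoder.patch_size
        tokens = self.encoder._process_input(x)            # (N, H/p * W/p, 768)
        cls = self.encoder.class_token.expand(n, -1, -1)    # (N, 1, 768)
        tokens = self.encoder.encoder(torch.cat([cls, tokens], dim=1))
        patches = tokens[:, 1:]                             # drop the class token
        return patches.transpose(1, 2).reshape(n, -1, h // p, w // p)  # (N, 768, H/p, W/p)


class SimpleSegmentationHead(nn.Module):
    def __init__(self, in_channels=768, num_classes=6):
        super().__init__()
        self.proj = nn.Sequential(
            # 1x1 convolutions map SSL feature channels to class channels.
            nn.Conv2d(in_channels, 256, kernel_size=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1),
        )

    def forward(self, features, image_size):
        logits = self.proj(features)
        # Upsample low-resolution logits to the original mask size.
        return F.interpolate(logits, size=image_size, mode="bilinear", align_corners=False)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pass the path to your own SSL checkpoint here; None keeps random weights.
backbone = SSLViTBackbone(checkpoint_path=None).to(device)
decoder = SimpleSegmentationHead(in_channels=768, num_classes=6).to(device)
# Pixels labeled 255 (e.g., unlabeled or boundary regions) are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)


def train_one_epoch(train_loader, optimizer):
    # train_loader is assumed to yield (images, masks): images of shape
    # (N, 3, 224, 224) and masks of shape (N, 224, 224) holding class indices.
    backbone.train()
    decoder.train()
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device).long()
        features = backbone(images)
        logits = decoder(features, image_size=masks.shape[-2:])
        loss = criterion(logits, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


# Phase 1: train only the decoder.
for param in backbone.parameters():
    param.requires_grad = False
for param in decoder.parameters():
    param.requires_grad = True
optimizer = torch.optim.AdamW(decoder.parameters(), lr=1e-4, weight_decay=1e-4)
# train_one_epoch(train_loader, optimizer)  # repeat for several epochs

# Phase 2: unfreeze the last two encoder blocks (layers 10 and 11 of 12).
# This adapts the SSL features to the segmentation dataset.
for name, param in backbone.named_parameters():
    if "encoder.layers.encoder_layer_10" in name or "encoder.layers.encoder_layer_11" in name:
        param.requires_grad = True

# Rebuild the optimizer after changing requires_grad so it holds exactly the
# trainable parameters, with a separate learning rate per group.
optimizer = torch.optim.AdamW([
    {
        "params": [p for p in backbone.parameters() if p.requires_grad],
        "lr": 1e-5,  # small LR protects pretrained SSL features
    },
    {
        "params": decoder.parameters(),
        "lr": 1e-4,  # larger LR because the decoder starts from scratch
    },
], weight_decay=1e-4)
# train_one_epoch(train_loader, optimizer)
```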
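The loss above uses `ignore_index=255`, a common convention for marking unlabeled or boundary pixels in mask files. The small check below, on made-up logits and a 2 × 2 mask, illustrates that such pixels are simply left out of the average: the loss with one pixel set to 255 equals the loss computed over the three labeled pixels alone.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes = 3
logits = torch.randn(1, num_classes, 2, 2)  # (N, C, H, W) raw class scores
mask = torch.tensor([[[0, 2],
                      [1, 255]]])           # (N, H, W); 255 = unlabeled pixel

criterion = nn.CrossEntropyLoss(ignore_index=255)
loss_with_ignore = criterion(logits, mask)

# The same loss computed by hand over the labeled pixels only.
flat_logits = logits.permute(0, 2, 3, 1).reshape(-1, num_classes)  # (4, C), row-major pixels
flat_mask = mask.reshape(-1)                                        # (4,)
keep = flat_mask != 255
loss_manual = nn.functional.cross_entropy(flat_logits[keep], flat_mask[keep])

print(loss_with_ignore.item(), loss_manual.item())  # equal up to floating-point rounding
```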
Training Notes for Students
When labels are scarce
Pretrain with SSL on unlabeled images from your own domain, then fine-tune on the smaller labeled mask dataset.
When masks are noisy
Freeze more of the backbone and train the decoder first, so that gradients from noisy masks do not overwrite the general visual features learned during pretraining.
When objects are small
Use higher image resolution, multi-scale features, and class-balanced losses if foreground pixels are rare.
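One simple way to make the loss class-balanced is to weight each class by the inverse of its pixel frequency in the training masks and pass those weights to `nn.CrossEntropyLoss`. The sketch below assumes masks are integer tensors with labels in `[0, num_classes)` plus 255 as the ignore label, matching the loss used earlier. The helper name `inverse_frequency_weights`, the smoothing term, and the rescaling to a mean weight of 1 are illustrative choices; median-frequency weighting and Dice loss are common alternatives.

```python
import torch
import torch.nn as nn


def inverse_frequency_weights(masks, num_classes, ignore_index=255, eps=1.0):
    """Per-class weights proportional to 1 / pixel count (illustrative sketch)."""
    counts = torch.zeros(num_classes)
    for mask in masks:
        labels = mask[mask != ignore_index]  # drop ignored pixels before counting
        counts += torch.bincount(labels.flatten(), minlength=num_classes).float()
    weights = 1.0 / (counts + eps)  # eps keeps classes absent from the masks finite
    return weights * num_classes / weights.sum()  # rescale so the weights average to 1


# Toy data: class 0 (background) dominates, class 1 (small object) is rare.
mask_a = torch.zeros(8, 8, dtype=torch.long)
mask_a[3:5, 3:5] = 1
mask_b = torch.zeros(8, 8, dtype=torch.long)
mask_b[0, 0] = 255

weights = inverse_frequency_weights([mask_a, mask_b], num_classes=2)
criterion = nn.CrossEntropyLoss(weight=weights, ignore_index=255)
print(weights)
```

With 123 background pixels and only 4 object pixels, the rare class receives a much larger weight, so misclassified object pixels contribute more to the loss.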