CSE 438 Digital Image Processing

Self-Supervised Learning + Semantic Segmentation

This page explains how an encoder pretrained with self-supervised learning (SSL) becomes the backbone of a dense prediction model. Pretraining on unlabeled images reduces the number of labeled masks needed and typically improves transfer when segmentation masks are limited.

Unlabeled images → SSL backbone → Segmentation decoder → Pixel mask

Stage 1

Pretrain encoder

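An SSL method such as DINO, MAE, JEPA, or BYOL learns visual features from unlabeled images, for example by matching teacher and student outputs across augmented crops or views (DINO, BYOL), reconstructing masked patches (MAE), or predicting the representations of masked regions (JEPA).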
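
Two of the ingredients named above, random masking and teacher-student matching, can be sketched in PyTorch. This is a minimal illustration, not the reference code of any of these methods: `random_patch_mask` and `ema_update` are hypothetical names, the 0.75 mask ratio follows MAE's common default, and 0.996 is a typical teacher momentum in DINO-style training. Losses, augmentations, and the networks themselves are omitted.

```python
import torch
import torch.nn as nn


def random_patch_mask(tokens: torch.Tensor, mask_ratio: float = 0.75):
    """MAE-style masking sketch: keep a random subset of patch tokens per image.

    tokens: (B, N, C) patch embeddings. Returns the visible tokens (B, N_keep, C)
    and a binary mask (B, N) where 1 marks a hidden patch.
    """
    b, n, c = tokens.shape
    n_keep = int(n * (1 - mask_ratio))
    noise = torch.rand(b, n, device=tokens.device)   # random score per patch
    ids_keep = noise.argsort(dim=1)[:, :n_keep]       # lowest scores stay visible
    visible = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, c))
    mask = torch.ones(b, n, device=tokens.device)
    mask.scatter_(1, ids_keep, 0.0)                   # 0 = visible, 1 = masked
    return visible, mask


@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, momentum: float = 0.996):
    """Teacher-student sketch: teacher weights track an exponential moving average
    of the student weights; only the student is trained by gradient descent."""
    for t, s in zip(teacher.parameters(), student.parameters()):
        t.mul_(momentum).add_(s.detach(), alpha=1.0 - momentum)
```

With 196 patches (a 224x224 image cut into 16x16 patches) and a 0.75 ratio, the encoder sees 49 visible tokens and the remaining 147 are hidden.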

Stage 2

Add decoder

A segmentation decoder upsamples the encoder features to the input resolution and produces class logits for every pixel.
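
A minimal decoder sketch, assuming the backbone returns patch tokens of shape (B, N, C) laid out row by row on an (H/16) x (W/16) grid, as a ViT-B/16 does. The conv head and its 256 hidden channels are illustrative choices; published models often use heavier decoders such as UPerNet or DPT, which this does not reproduce. `SimpleSegDecoder` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleSegDecoder(nn.Module):
    """Sketch of a light decoder: patch tokens -> feature map -> per-pixel logits."""

    def __init__(self, embed_dim: int, num_classes: int, patch_size: int = 16):
        super().__init__()
        self.patch_size = patch_size
        self.head = nn.Sequential(
            nn.Conv2d(embed_dim, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1),   # one logit per class
        )

    def forward(self, tokens: torch.Tensor, image_size: tuple) -> torch.Tensor:
        b, n, c = tokens.shape
        h, w = image_size[0] // self.patch_size, image_size[1] // self.patch_size
        if n != h * w:
            raise ValueError(f"expected {h * w} patch tokens, got {n}")
        # (B, N, C) -> (B, C, H/16, W/16): put the patch grid back in 2D
        feats = tokens.transpose(1, 2).reshape(b, c, h, w)
        logits = self.head(feats)
        # Upsample coarse logits to full resolution: (B, num_classes, H, W)
        return F.interpolate(logits, size=image_size, mode="bilinear", align_corners=False)
```

Taking `logits.argmax(dim=1)` turns the (B, num_classes, H, W) output into the (B, H, W) pixel mask.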

Stage 3

Fine-tune masks

Supervised training on labeled masks updates the decoder and, optionally, the backbone. A common schedule freezes the backbone first, then unfreezes it with a smaller learning rate than the decoder.
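
The freeze-then-unfreeze schedule can be sketched with optimizer parameter groups. Assumptions: `set_trainable`, `make_optimizer`, and `segmentation_loss` are hypothetical helpers, and the learning rates, the 0.1 backbone scale, the AdamW weight decay, and the 255 ignore label are illustrative values, not prescribed by this page.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def set_trainable(module: nn.Module, trainable: bool) -> None:
    """Freeze (False) or unfreeze (True) every parameter of a module."""
    for p in module.parameters():
        p.requires_grad = trainable


def make_optimizer(backbone: nn.Module, decoder: nn.Module,
                   decoder_lr: float = 1e-3, backbone_lr_scale: float = 0.1):
    """Build AdamW with a smaller learning rate for the backbone, if unfrozen."""
    groups = [{"params": list(decoder.parameters()), "lr": decoder_lr}]
    backbone_params = [p for p in backbone.parameters() if p.requires_grad]
    if backbone_params:  # empty while the backbone is frozen
        groups.append({"params": backbone_params, "lr": decoder_lr * backbone_lr_scale})
    return torch.optim.AdamW(groups, weight_decay=0.05)


def segmentation_loss(logits: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    """Per-pixel cross-entropy; logits (B, K, H, W), masks (B, H, W) class ids.
    Pixels labeled 255 are ignored (an assumed labeling convention)."""
    return F.cross_entropy(logits, masks, ignore_index=255)
```

Phase 1 calls `set_trainable(backbone, False)` and builds the optimizer, so only the decoder updates. Phase 2 calls `set_trainable(backbone, True)` and builds a fresh optimizer, which puts the backbone in its own group at a tenth of the decoder's learning rate.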

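import torch
import torch.nn as nn
from torchvision.models import vit_b_16

class SSLViTBackbone(nn.Module):
    def __init__(self, checkpoint_path):
        super().__init__()
        self.encoder = vit_b_16(weights=None)
        self.encoder.heads = nn.Identity()

        # Load weights learned from an SSL method such as MAE or DINO.
        # strict=False tolerates extra keys (e.g. a projector head), but it also
        # silently skips keys whose names differ from torchvision's ViT naming,
        # so report what was not loaded instead of assuming success.
        state = torch.load(checkpoint_path, map_location="cpu")
        missing, unexpected = self.encoder.load_state_dict(state, strict=False)
        print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")

    def forward(self, x):
        # Dense prediction needs one feature per patch, not the single class
        # token that VisionTransformer.forward returns. _process_input is a
        # private torchvision helper that patchifies x (224x224 for vit_b_16).
        tokens = self.encoder._process_input(x)                # (B, 196, 768)
        cls = self.encoder.class_token.expand(tokens.shape[0], -1, -1)
        tokens = self.encoder.encoder(torch.cat([cls, tokens], dim=1))
        return tokens[:, 1:]                                   # drop class token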
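
Checkpoints saved by SSL training code often wrap parameter names in prefixes (for example `module.` from `DistributedDataParallel`, or `backbone.` when the encoder sits inside a teacher or student wrapper) and include heads that the segmentation model does not need. The hypothetical helper below strips such prefixes and drops such heads before `load_state_dict`; the prefix and head names are assumptions to adjust per checkpoint. It does not translate naming schemes: weights from codebases that use names like `blocks.0.attn.qkv.weight` still need a key-by-key mapping to torchvision's `encoder.layers.encoder_layer_N` naming.

```python
def clean_ssl_state_dict(state_dict, strip=("module.", "backbone."),
                         drop=("head.", "projector.")):
    """Strip wrapper prefixes and drop SSL-only heads from a checkpoint dict.

    The default prefixes are illustrative assumptions, not a fixed standard.
    """
    cleaned = {}
    for key, value in state_dict.items():
        # Repeatedly remove known prefixes, e.g. "module.backbone.x" -> "x".
        stripped = True
        while stripped:
            stripped = False
            for prefix in strip:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    stripped = True
        if key.startswith(drop):  # str.startswith accepts a tuple of prefixes
            continue
        cleaned[key] = value
    return cleaned
```

A typical call is `state = clean_ssl_state_dict(torch.load(path, map_location="cpu"))`; if the file nests its weights under an entry such as `model` or `teacher`, select that entry first.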

When labels are scarce

Use SSL pretraining on unlabeled images from the target domain, then fine-tune on the smaller labeled mask dataset.

When masks are noisy

Freeze more of the backbone and train the decoder first, so that noisy masks do not overwrite the general visual features learned during pretraining.

When objects are small

Use a higher input resolution and multi-scale features, and use class-balanced losses if foreground pixels are rare.
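
For the small-object case, one common class-balanced choice is inverse-frequency weighting, the same rule as scikit-learn's "balanced" class weights: each class gets total_pixels / (num_classes x pixels_of_that_class). The sketch below assumes integer masks with 255 as the ignore label; `balanced_class_weights` is a hypothetical name, and in practice the counts would come from the whole training set rather than a single batch.

```python
import torch
import torch.nn.functional as F


def balanced_class_weights(masks: torch.Tensor, num_classes: int,
                           ignore_index: int = 255) -> torch.Tensor:
    """Inverse-frequency weights: total_pixels / (num_classes * pixels_of_class)."""
    valid = masks[masks != ignore_index]  # 1-D tensor of labeled pixels
    counts = torch.bincount(valid.flatten(), minlength=num_classes).float()
    # clamp avoids division by zero for classes absent from the masks
    return counts.sum() / (num_classes * counts.clamp(min=1.0))
```

The result is passed as `weight=` to `F.cross_entropy(logits, masks, weight=w, ignore_index=255)`. With 90 background and 10 foreground pixels, the weights are 100 / 180 ≈ 0.56 and 100 / 20 = 5.0, so each rare foreground pixel counts about nine times as much as a background pixel.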