SimCLR Training Objective

Get introduced to SimCLR’s network architecture and its loss function.

Now that we have two augmented versions of the input batch, $T_1(B)$ and $T_2(B)$, we'll look into the other components of the SimCLR training pipeline.

Network architecture

As shown in the figure below, the two augmented versions of an image $X_i$ (i.e., $T_1(X_i)$ and $T_2(X_i)$) are passed through the neural network $f(\cdot)$ to get the penultimate feature representations $h_{i1}$ and $h_{i2}$, respectively. These feature representations are then passed through a multilayer perceptron (MLP) projection head $g(\cdot)$ to get the feature embeddings $z_{i1}$ and $z_{i2}$, respectively.

[Figure: Illustration of how SimCLR works]
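
In other words, each embedding is the composition of the backbone and the projection head applied to one augmented view of the image:

$z_{i1} = g(f(T_1(X_i))) \quad \text{and} \quad z_{i2} = g(f(T_2(X_i)))$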

The code example below implements the class SimCLR_Network that passes the input image to a resnet18 backbone ($f$) and an MLP projection head ($g$).

main.py (the Augment class is imported from utils.py)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision
import torchvision.models as models
from utils import Augment
from PIL import Image

class SimCLR_Network(nn.Module):
    def __init__(self, embed_dim=512):
        super(SimCLR_Network, self).__init__()
        self.backbone = models.resnet18()  # resnet18 backbone f
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # remove the fc layer of resnet18
        # add the MLP projection head g
        self.projection = nn.Sequential(
            nn.Linear(in_features, embed_dim),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(embed_dim),
        )

    def forward(self, x):
        f = self.backbone(x)       # penultimate feature representation h
        return self.projection(f)  # feature embedding z

network = SimCLR_Network()
batch = [Image.open("n02107683_Bernese_mountain_dog.jpeg"), Image.open("cat.jpg")]
batch = [T.functional.to_tensor(img.resize((224, 224))) for img in batch]
batch = torch.stack(batch)
torchvision.utils.save_image(batch, "./output/image.png", normalize=True)

augment = Augment(img_size=224)
aug1, aug2 = augment(batch), augment(batch)  # generate two augmented versions of the batch
torchvision.utils.save_image(aug1, "./output/t1_image.png", normalize=True)
torchvision.utils.save_image(aug2, "./output/t2_image.png", normalize=True)

z1, z2 = network(aug1), network(aug2)  # feature embeddings
print("Shape of z1 and z2 is", z1.shape, "and", z2.shape)
  1. We implement the class SimCLR_Network, which passes the input image through a resnet18 backbone ($f$) and an MLP projection head ($g$).

  2. In __init__(), we define the feature backbone self.backbone as a resnet18 network.

  3. We remove resnet18's fully connected classification layer by reinitializing self.backbone.fc as an nn.Identity() layer. self.backbone now takes a batch of images ($\text{batch size} \times 3 \times 224 \times 224$) and returns a 512-dimensional feature vector for each image.

  4. We define the projection head self.projection as an MLP using the nn.Linear, nn.ReLU, and nn.BatchNorm1d layers. This projection head takes resnet18's 512-dimensional features from self.backbone and maps them to the embed_dim-dimensional (by default, 512-dimensional) embeddings $z_{i1}$ and $z_{i2}$. These embeddings are what SimCLR's contrastive loss compares, as sketched below.
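
With the embeddings $z_{i1}$ and $z_{i2}$ in hand, the remaining component of the training objective is the contrastive loss. The sketch below shows the NT-Xent (normalized temperature-scaled cross-entropy) loss that SimCLR optimizes. It is a minimal illustration, not the lesson's implementation: the function name simclr_loss and the temperature value of 0.5 are our own choices, and we assume z1 and z2 are the (batch size × embed_dim) outputs from main.py above.

import torch
import torch.nn.functional as F

def simclr_loss(z1, z2, temperature=0.5):
    # Stack the two augmented views: rows 0..N-1 come from T1(B), rows N..2N-1 from T2(B)
    N = z1.shape[0]
    z = torch.cat([z1, z2], dim=0)      # (2N, embed_dim)
    z = F.normalize(z, dim=1)           # unit vectors, so dot products are cosine similarities
    sim = z @ z.T / temperature         # (2N, 2N) pairwise similarity logits
    sim.fill_diagonal_(float("-inf"))   # an embedding is never its own candidate
    # The positive for view i is the other augmented view of the same image
    targets = torch.cat([torch.arange(N, 2 * N), torch.arange(N)])
    return F.cross_entropy(sim, targets)

loss = simclr_loss(z1, z2)

For each image, its two augmented views form a positive pair, while the remaining $2N - 2$ embeddings in the batch act as negatives; the cross-entropy over similarities pulls each positive pair together and pushes it away from all the negatives.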