SimCLR Training Objective
Get introduced to SimCLR’s network architecture and its loss function.
Now that we have two augmented versions of the input batch,
Network architecture
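As shown in the figure below, the two augmented versions of an image are passed through the same network (a backbone followed by a projection head) to obtain their feature embeddings.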
The code example below implements the class `SimCLR_Network`, which passes the input image through a `resnet18` backbone and an MLP projection head.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision
import torchvision.models as models
from utils import Augment
from PIL import Image

class SimCLR_Network(nn.Module):
    def __init__(self, embed_dim=512):
        super(SimCLR_Network, self).__init__()
        self.backbone = models.resnet18()  # resnet18 backbone
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # remove the fc layer of resnet18

        # add mlp projection head
        self.projection = nn.Sequential(
            nn.Linear(in_features, embed_dim),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(embed_dim),
        )

    def forward(self, x):
        f = self.backbone(x)
        return self.projection(f)

network = SimCLR_Network()
batch = [Image.open("n02107683_Bernese_mountain_dog.jpeg"), Image.open("cat.jpg")]
batch = [T.functional.to_tensor(img.resize((224, 224))) for img in batch]
batch = torch.stack(batch)
torchvision.utils.save_image(batch, "./output/image.png", normalize=True)

augment = Augment(img_size=224)
aug1, aug2 = augment(batch), augment(batch)  # generate two augmented versions of batch
torchvision.utils.save_image(aug1, "./output/t1_image.png", normalize=True)
torchvision.utils.save_image(aug2, "./output/t2_image.png", normalize=True)

z1, z2 = network(aug1), network(aug2)  # feature embeddings
print("Shape of z1 and z2 is", z1.shape, "and ", z2.shape)
```
Line 10: We implement the class `SimCLR_Network`, which passes the input image through a `resnet18` backbone and an MLP projection head.

Line 13: We define the feature backbone `self.backbone` as a `resnet18` network.

Lines 14–15: We record the input size of the fully connected classification layer of `resnet18` in `in_features`, then remove that layer by replacing it with an `nn.Identity()` layer. As a result, `self.backbone` takes an image and returns a 512-dimensional feature vector.

Lines 18–24: We define the projection head `self.projection` as an MLP built from `nn.Linear`, `nn.ReLU`, and `nn.BatchNorm1d` layers. This projection layer takes `resnet18`'s 512-dimensional features from `self.backbone` and ...