Contrastive Learning with SimCLR in PyTorch

Last Updated : 23 Jul, 2025

SimCLR (A Simple Framework for Contrastive Learning of Visual Representations) is a self-supervised learning approach that learns image representations without labeled data. It trains an encoder to maximize agreement between two differently augmented views of the same image, and to minimize similarity with views of other images in the batch, using a contrastive loss in the latent space. Implementing SimCLR in PyTorch makes it easy to experiment with each component while training on unlabeled data only.

Core Ideas of SimCLR

SimCLR in PyTorch: Main Components

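**1. Data Augmentation:** Define a set of strong augmentations to generate two different views of each image.

**2. Encoder and Projection Head:** A ResNet-18 encoder extracts features, and a small MLP projection head maps them to the space where the contrastive loss is applied.

**3. Contrastive Loss Implementation:** A custom loss function (NT-Xent) computes the contrastive loss for each positive pair in the batch.

**4. Training Loop:** Each batch of paired views is passed through the model, and the NT-Xent loss is minimized with the Adam optimizer.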
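For reference, the NT-Xent loss named in component 3 is usually written as follows for a positive pair of views (i, j) among the 2N views produced from a batch of N images. The symbols τ (temperature) and sim (cosine similarity) follow the notation of the SimCLR paper and are not spelled out in the text above:

```latex
\ell_{i,j} = -\log
  \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}
       {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)},
\qquad
\mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert}
```

The batch loss is the average of this term over all 2N ordered positive pairs, that is, both (i, j) and (j, i) for every image.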

PyTorch Implementation

1. Install Libraries

Installs the required PyTorch packages to run the model.

```python
!pip install torch torchvision pytorch-lightning --quiet
```


Setup: Import Libraries

The standard imports cover models, datasets, transformations, and helper utilities (such as tqdm for progress bars).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from torchvision.datasets import CIFAR10
import torchvision.models as models
import numpy as np
import random
from tqdm import tqdm
```

2. Data Augmentation

The augmentations create two different views of the same image to learn invariant features.

```python
simclr_transform = T.Compose([
    T.RandomResizedCrop(32),                                       # random crop, resized back to 32x32
    T.RandomHorizontalFlip(),
    T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),     # color distortion
    T.RandomGrayscale(p=0.2),
    T.GaussianBlur(kernel_size=3),
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),  # CIFAR-10 channel mean / std
])
```

3. Dataset for Two Views

For each image, return two augmented versions: xi and xj, used as positive pairs in contrastive learning.

```python
class SimCLRDataset(Dataset):
    def __init__(self, base_dataset, transform):
        self.dataset = base_dataset
        self.transform = transform

    def __getitem__(self, index):
        image, _ = self.dataset[index]  # the label is ignored
        # The transform is random, so two calls give two different views
        xi = self.transform(image)
        xj = self.transform(image)
        return xi, xj

    def __len__(self):
        return len(self.dataset)
```
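The two-view behavior can be checked without downloading any data. The following is a minimal sketch, not the article's pipeline: `TwoViewDataset` repeats the same logic as `SimCLRDataset`, while `noisy_view` and `toy_base` are hypothetical stand-ins for `simclr_transform` and CIFAR-10, chosen only so the example runs on plain tensors.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TwoViewDataset(Dataset):
    """Same idea as SimCLRDataset: apply one random transform twice per item."""

    def __init__(self, base_dataset, transform):
        self.dataset = base_dataset
        self.transform = transform

    def __getitem__(self, index):
        image, _ = self.dataset[index]  # the label is ignored
        return self.transform(image), self.transform(image)

    def __len__(self):
        return len(self.dataset)


def noisy_view(image):
    # Hypothetical stand-in for simclr_transform: fresh Gaussian noise on every call
    return image + 0.1 * torch.randn_like(image)


torch.manual_seed(0)
# Toy base dataset: 8 random "images" shaped like CIFAR-10 inputs, with dummy labels
toy_base = [(torch.rand(3, 32, 32), 0) for _ in range(8)]
toy_loader = DataLoader(TwoViewDataset(toy_base, noisy_view), batch_size=4)

x_i, x_j = next(iter(toy_loader))
print(x_i.shape, x_j.shape)    # both batches have shape (4, 3, 32, 32)
print(torch.equal(x_i, x_j))   # False: the two views of each image differ
```

Each batch therefore yields two aligned tensors, where row k of `x_i` and row k of `x_j` come from the same source image and form a positive pair.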

4. SimCLR Model = Encoder + Projection Head

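```python
class SimCLRModel(nn.Module):
    def __init__(self, projection_dim=128):
        super().__init__()
        # ResNet-18 backbone without pretrained weights
        base_model = models.resnet18(weights=None)
        num_ftrs = base_model.fc.in_features
        base_model.fc = nn.Identity()  # drop the classification layer
        self.encoder = base_model
        # MLP projection head used only for the contrastive loss
        self.projection_head = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim)
        )

    def forward(self, x):
        h = self.encoder(x)          # representation
        z = self.projection_head(h)  # projection fed to the loss
        return z
```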

5. NT-Xent Loss

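```python
def nt_xent_loss(z_i, z_j, temperature=0.5):
    # Stack both views and L2-normalize, so dot products are cosine similarities
    z = torch.cat([z_i, z_j], dim=0)
    z = F.normalize(z, dim=1)

    similarity = torch.matmul(z, z.T)
    N = z_i.shape[0]

    # Mask out each sample's similarity with itself
    mask = ~torch.eye(2 * N, dtype=torch.bool, device=z.device)
    sim = similarity / temperature
    exp_sim = torch.exp(sim) * mask

    # Numerator: similarity of each positive pair, for both (i, j) and (j, i)
    positive_sim = torch.exp(F.cosine_similarity(z_i, z_j) / temperature)
    positives = torch.cat([positive_sim, positive_sim], dim=0)

    # Denominator: all other samples in the batch
    denominator = exp_sim.sum(dim=1)
    loss = -torch.log(positives / denominator)
    return loss.mean()
```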
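The same loss can also be expressed as a cross-entropy over similarity logits, which avoids calling `torch.exp` directly. The sketch below is an alternative formulation rather than the article's code: `nt_xent_direct` mirrors the function above, `nt_xent_cross_entropy` is a hypothetical name for the rewritten version, and random embeddings are used only to check that the two agree.

```python
import torch
import torch.nn.functional as F


def nt_xent_direct(z_i, z_j, temperature=0.5):
    # Direct form, mirroring the function above
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    n = z_i.shape[0]
    sim = z @ z.T / temperature
    mask = ~torch.eye(2 * n, dtype=torch.bool, device=z.device)
    denominator = (torch.exp(sim) * mask).sum(dim=1)
    pos = torch.exp(F.cosine_similarity(z_i, z_j) / temperature)
    positives = torch.cat([pos, pos], dim=0)
    return -torch.log(positives / denominator).mean()


def nt_xent_cross_entropy(z_i, z_j, temperature=0.5):
    # Treat each row of the similarity matrix as logits over the 2N views
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    n = z_i.shape[0]
    logits = z @ z.T / temperature
    # A view must never be classified as its own positive
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # Row k (first half) has its positive at k + n; row k + n has it at k
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(logits, targets)


torch.manual_seed(0)
a, b = torch.randn(8, 128), torch.randn(8, 128)
direct = nt_xent_direct(a, b)
stable = nt_xent_cross_entropy(a, b)
print(torch.allclose(direct, stable, atol=1e-4))  # True: both forms agree
```

`F.cross_entropy` applies a log-softmax internally, so this form tends to be more robust at low temperatures, where the exponentials in the direct form can grow very large.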

6. Training Loop

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CIFAR-10 images are used without their labels during pretraining
train_dataset = CIFAR10(root='./data', train=True, download=True)
contrastive_dataset = SimCLRDataset(train_dataset, simclr_transform)
train_loader = DataLoader(contrastive_dataset, batch_size=256, shuffle=True, num_workers=2)

model = SimCLRModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

for epoch in range(10):
    model.train()
    total_loss = 0
    for x_i, x_j in tqdm(train_loader):
        x_i, x_j = x_i.to(device), x_j.to(device)
        # Project both views with the same model
        z_i = model(x_i)
        z_j = model(x_j)

        loss = nt_xent_loss(z_i, z_j)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Average loss over all batches in this epoch
    print(f"Epoch {epoch+1} | Loss: {total_loss / len(train_loader):.4f}")
```


7. Evaluation (Linear Probe)

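```python
# Freeze the encoder so that only the linear classifier is trained
for param in model.encoder.parameters():
    param.requires_grad = False

# ResNet-18 features are 512-dimensional; CIFAR-10 has 10 classes
classifier = nn.Linear(512, 10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

def get_features_and_labels(loader):
    """Run the frozen encoder over a labeled loader and collect its features."""
    model.eval()  # use BatchNorm running statistics during feature extraction
    features, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            h = model.encoder(x)
            features.append(h.cpu())
            labels.append(y)
    return torch.cat(features), torch.cat(labels)
```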
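The snippet above stops after feature extraction. One way the probe could be completed is sketched below, under stated assumptions: `train_linear_probe` and `accuracy` are hypothetical helpers, and the synthetic 512-dimensional clusters stand in for the tensors that `get_features_and_labels` would return on labeled CIFAR-10 train and test loaders (built with a plain `ToTensor` and `Normalize` transform rather than the SimCLR augmentations).

```python
import torch
import torch.nn as nn


def train_linear_probe(features, labels, num_classes, epochs=20, lr=1e-3, batch_size=256):
    """Fit a single linear layer on fixed features from a frozen encoder."""
    classifier = nn.Linear(features.shape[1], num_classes)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for _ in range(epochs):
        perm = torch.randperm(features.shape[0])  # reshuffle every epoch
        for start in range(0, features.shape[0], batch_size):
            idx = perm[start:start + batch_size]
            loss = criterion(classifier(features[idx]), labels[idx])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return classifier


@torch.no_grad()
def accuracy(classifier, features, labels):
    preds = classifier(features).argmax(dim=1)
    return (preds == labels).float().mean().item()


# Synthetic stand-in for encoder features: 10 well-separated clusters in 512-d
torch.manual_seed(0)
centers = torch.randn(10, 512) * 3
train_y = torch.randint(0, 10, (2000,))
test_y = torch.randint(0, 10, (500,))
train_x = centers[train_y] + torch.randn(2000, 512)
test_x = centers[test_y] + torch.randn(500, 512)

probe = train_linear_probe(train_x, train_y, num_classes=10)
test_acc = accuracy(probe, test_x, test_y)
print(f"Linear probe accuracy on synthetic features: {test_acc:.3f}")
```

With real SimCLR features, the accuracy of this probe on the CIFAR-10 test split is the usual measure of how linearly separable the learned representations are. The synthetic clusters here only confirm that the training and evaluation code runs.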

You can download the complete code from here: Contrastive Learning using SimCLR

Practical Considerations

Training and Evaluation