Building a Vision Transformer from Scratch in PyTorch

Last Updated : 23 Jul, 2025

Vision Transformers (ViTs) have had a major impact on computer vision by applying the transformer architecture, originally designed for natural language processing, to images. Unlike traditional CNNs, ViTs divide an image into patches and treat them as tokens, and self-attention lets the model learn relationships between any two regions of the image. In this tutorial, we'll walk through building a Vision Transformer from scratch in PyTorch, from patch embedding to training the model on CIFAR-10.


What is a Vision Transformer?

A Vision Transformer (ViT) is a deep learning architecture that applies transformers to computer vision tasks. Convolutional neural networks (CNNs) have traditionally been the dominant model for vision applications, but ViTs take a different approach: instead of processing images with convolutions, they split an image into smaller patches and treat each patch as a token (much like a word in NLP), feeding the resulting sequence into a transformer. Because self-attention lets every patch interact with every other patch in every layer, a ViT can capture long-range dependencies across the whole image, which makes it effective for tasks like image classification.
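As a concrete illustration of the patch-as-token idea, the sketch below cuts a batch of images into non-overlapping patches using only `reshape` and `permute`. The helper name `patchify` is hypothetical and the random tensor stands in for real images; with 224×224 RGB images and 16×16 patches, each image becomes 14 × 14 = 196 tokens of 3 × 16 × 16 = 768 raw pixel values.

```python
import torch

def patchify(images, patch_size):
    """Split (B, C, H, W) images into (B, num_patches, C * patch_size**2) flattened patches."""
    B, C, H, W = images.shape
    p = patch_size
    assert H % p == 0 and W % p == 0, "image height and width must be divisible by patch_size"
    # (B, C, H/p, p, W/p, p): split each spatial axis into (patch index, offset within patch)
    x = images.reshape(B, C, H // p, p, W // p, p)
    # (B, H/p, W/p, C, p, p): move the patch-grid axes next to the batch axis
    x = x.permute(0, 2, 4, 1, 3, 5)
    # (B, num_patches, C*p*p): one flattened row per patch, in row-major grid order
    return x.reshape(B, (H // p) * (W // p), C * p * p)

images = torch.randn(2, 3, 224, 224)  # random stand-in for a batch of RGB images
tokens = patchify(images, patch_size=16)
print(tokens.shape)  # torch.Size([2, 196, 768])
```

The tokens are ordered row by row across the patch grid, which is also the order in which the convolution-based patch embedding flattens its output.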

Key Concepts Behind Vision Transformers:

Why Use Transformers for Vision Tasks?

Transformers have proven highly effective in natural language processing (NLP), where self-attention lets each token draw on context from the entire sequence. By applying transformers to vision tasks, we can overcome some of the limitations of CNNs:

Building the Vision Transformer from Scratch

Let's implement a Vision Transformer from scratch in PyTorch, covering patch embedding, positional encoding, multi-head attention, transformer encoder blocks, and training on the CIFAR-10 dataset. Below is a step-by-step guide.

1. Dividing the Image into Patches

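Vision Transformers first divide an image into fixed-size patches. Each patch is flattened into a vector, which is then embedded using a linear projection. The code below performs both steps with a single `nn.Conv2d` whose kernel size and stride equal the patch size, so each output position covers exactly one patch and the same projection is applied to every patch.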

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        # kernel_size == stride == patch_size: the conv applies one shared linear
        # projection to each non-overlapping patch
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H/P, W/P) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
```
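The text describes flattening each patch and applying a linear projection, while `PatchEmbedding` uses `nn.Conv2d`. The check below, using small arbitrary sizes, illustrates that the two are the same computation when kernel size and stride both equal the patch size: copying the conv kernel into an `nn.Linear` and applying it to patches cut out with `Tensor.unfold` reproduces the conv tokens.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 32, 32, 8, 64  # small, arbitrary sizes

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
linear = nn.Linear(C * P * P, D)
with torch.no_grad():
    # Conv weight (D, C, P, P) flattened to the Linear layout (D, C*P*P)
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)

    x = torch.randn(B, C, H, W)
    # Path 1: strided convolution, as in PatchEmbedding
    conv_tokens = conv(x).flatten(2).transpose(1, 2)                        # (B, N, D)
    # Path 2: cut out the patches explicitly, flatten each, apply the Linear layer
    patches = x.unfold(2, P, P).unfold(3, P, P)                             # (B, C, H/P, W/P, P, P)
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)  # (B, N, C*P*P)
    linear_tokens = linear(patches)                                         # (B, N, D)

print(conv_tokens.shape)                                      # torch.Size([2, 16, 64])
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-5))  # True
```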

2. Adding Positional Embeddings

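Self-attention by itself has no notion of token order: shuffling the patch tokens would simply shuffle the outputs in the same way. To preserve the spatial layout, we add a learnable positional embedding to each token, with one extra position reserved for the [CLS] token used for classification.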

```python
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, seq_len):
        super().__init__()
        # One learnable vector per position; seq_len + 1 leaves room for the [CLS] token
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len + 1, embed_dim))

    def forward(self, x):
        # x: (B, seq_len + 1, embed_dim); pos_embed broadcasts over the batch dimension
        return x + self.pos_embed
```
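To see why positional information is needed, the sketch below feeds a token sequence and a reversed copy of it through PyTorch's `nn.MultiheadAttention`. Without positions, the reversed input yields exactly the reversed output, so attention cannot tell the two orderings apart; adding a per-position vector (a random stand-in here for the learned `pos_embed`) breaks that symmetry. The sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, D = 10, 32
attn = nn.MultiheadAttention(embed_dim=D, num_heads=4, batch_first=True)
x = torch.randn(1, N, D)            # one sequence of N patch tokens
perm = torch.arange(N - 1, -1, -1)  # reverse the token order

with torch.no_grad():
    # Without positional information, reordering the inputs just reorders the outputs
    out = attn(x, x, x)[0]
    xp = x[:, perm]
    out_p = attn(xp, xp, xp)[0]
    equivariant_without_pos = torch.allclose(out[:, perm], out_p, atol=1e-5)

    # Random stand-in for a learned positional embedding: tied to positions, not tokens
    pos = torch.randn(1, N, D)
    y, yp = x + pos, x[:, perm] + pos
    out = attn(y, y, y)[0]
    out_p = attn(yp, yp, yp)[0]
    equivariant_with_pos = torch.allclose(out[:, perm], out_p, atol=1e-5)

print(equivariant_without_pos)  # True: token order is invisible to attention
print(equivariant_with_pos)     # False: positional embeddings make order matter
```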

3. Defining the Multi-Head Self-Attention Mechanism

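Multi-head self-attention lets every patch token attend to every other token, so the model can relate distant regions of the image as well as neighboring ones. Each head has its own query, key, and value projections, allowing different heads to focus on different relationships in parallel. Here we wrap PyTorch's built-in `nn.MultiheadAttention`.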

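```python
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # batch_first=True: inputs are (batch, tokens, embed_dim). With the default
        # (batch_first=False) the batch axis would be mistaken for the token axis.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):
        # Self-attention: queries, keys and values all come from x.
        # nn.MultiheadAttention returns (output, weights); keep only the output.
        return self.attn(x, x, x, need_weights=False)[0]
```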
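The `MultiHeadAttention` wrapper delegates the math to `nn.MultiheadAttention`. For a from-scratch view, here is a minimal sketch of multi-head self-attention, `softmax(Q @ K.T / sqrt(head_dim)) @ V` computed per head, followed by a cross-check against the built-in layer. The class name `ScratchMultiHeadSelfAttention` is hypothetical, and the weight copy assumes the built-in layer keeps its packed Q/K/V projection in `in_proj_weight` and `in_proj_bias`, which holds when query, key, and value share `embed_dim` as they do here.

```python
import math
import torch
import torch.nn as nn

class ScratchMultiHeadSelfAttention(nn.Module):
    """Illustrative multi-head self-attention for inputs of shape (B, N, D)."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)  # joint Q, K, V projection
        self.out = nn.Linear(embed_dim, embed_dim)      # output projection

    def forward(self, x):
        B, N, D = x.shape
        # (B, N, 3D) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention, computed separately for every head: (B, heads, N, N)
        weights = (q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)).softmax(dim=-1)
        # Weighted sum of the values, then concatenate the heads back to (B, N, D)
        ctx = (weights @ v).transpose(1, 2).reshape(B, N, D)
        return self.out(ctx)

# Cross-check against PyTorch's built-in layer by copying its weights
torch.manual_seed(0)
ref = nn.MultiheadAttention(64, 8, batch_first=True)
mine = ScratchMultiHeadSelfAttention(64, 8)
with torch.no_grad():
    mine.qkv.weight.copy_(ref.in_proj_weight)
    mine.qkv.bias.copy_(ref.in_proj_bias)
    mine.out.weight.copy_(ref.out_proj.weight)
    mine.out.bias.copy_(ref.out_proj.bias)
    x = torch.randn(2, 5, 64)
    match = torch.allclose(mine(x), ref(x, x, x)[0], atol=1e-5)
print(match)  # True
```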

4. Transformer Encoder Block

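A Transformer encoder block consists of two sublayers: multi-head self-attention followed by a feed-forward network (MLP). Each sublayer is preceded by layer normalization and wrapped in a residual connection, the pre-norm arrangement used in ViT.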

```python
class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim):
        super().__init__()
        self.attn = MultiHeadAttention(embed_dim, num_heads)
        # Position-wise feed-forward network (the original ViT paper uses GELU here)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, embed_dim),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Pre-norm residual sublayers: normalize, transform, add back to the input
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
```
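PyTorch also ships a built-in layer with the same pre-norm structure. As a cross-check, the sketch below configures `nn.TransformerEncoderLayer` to mirror `TransformerEncoderBlock(768, 8, 1024)`: `norm_first=True` for pre-norm, a ReLU MLP, and `dropout=0.0` because the hand-written block has no dropout. It shows that a block maps `(B, N, D)` to `(B, N, D)`, which is what allows `depth` blocks to be stacked, and reports a parameter count that the hand-written block should match, since it holds the same layers.

```python
import torch
import torch.nn as nn

# Built-in encoder layer configured like TransformerEncoderBlock(768, 8, 1024)
layer = nn.TransformerEncoderLayer(
    d_model=768, nhead=8, dim_feedforward=1024, dropout=0.0,
    activation="relu", batch_first=True, norm_first=True,
)

x = torch.randn(2, 197, 768)  # 196 patch tokens + 1 [CLS] token, the default ViT setup
y = layer(x)
num_params = sum(p.numel() for p in layer.parameters())

print(y.shape)     # torch.Size([2, 197, 768])
print(num_params)  # 3940096
```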

5. Building the Vision Transformer Architecture

Finally, we stack the encoder blocks to define the Vision Transformer. A learnable [CLS] token is prepended to the patch sequence, and after the last block its representation is passed to a linear classification head.

```python
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, num_classes=10, embed_dim=768,
                 num_heads=8, depth=6, mlp_dim=1024):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        num_patches = (img_size // patch_size) ** 2
        # Learnable [CLS] token; PositionalEncoding adds one slot for it (num_patches + 1)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_encoding = PositionalEncoding(embed_dim, num_patches)
        self.transformer_blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_dim) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)  # final LayerNorm before the head, as in ViT
        self.mlp_head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embedding(x)                    # (B, N, D)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        x = torch.cat((cls_tokens, x), dim=1)          # (B, N + 1, D)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(x)
        # Classify from the final [CLS] token representation
        return self.mlp_head(self.norm(x[:, 0]))
```
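With the default arguments (img_size 224, patch_size 16, embed_dim 768, 10 classes), the token bookkeeping inside `VisionTransformer` works out as follows. This is a worked example of the shapes, not additional code:

```latex
\begin{aligned}
N &= \frac{H}{P} \cdot \frac{W}{P} = \frac{224}{16} \cdot \frac{224}{16} = 14 \cdot 14 = 196 \quad \text{(patch tokens)} \\
N + 1 &= 197 \quad \text{(tokens after prepending [CLS])} \\
\text{shapes:}\quad & (B, 3, 224, 224) \xrightarrow{\text{PatchEmbedding}} (B, 196, 768)
  \xrightarrow{\text{[CLS]} + \text{pos}} (B, 197, 768) \\
& \xrightarrow{\text{depth} \times \text{block}} (B, 197, 768)
  \xrightarrow{\texttt{x[:,0]}} (B, 768) \xrightarrow{\text{head}} (B, 10)
\end{aligned}
```

Each self-attention layer therefore scores 197 × 197 token pairs per head, a cost that grows quadratically with the number of patches.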

6. Training the Vision Transformer

To train the model, we use CIFAR-10, resize its 32×32 images to the 224×224 input size the model expects, and run a standard training loop with cross-entropy loss and the Adam optimizer.

```python
import torch.optim as optim
from torchvision import datasets, transforms

# CIFAR-10 images are 32x32; upsample them to the 224x224 input the model expects
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)

# Use the GPU if one is available; the model and each batch must be on the same device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionTransformer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean training loss over the batches of this epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}")
```
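The loop above only reports training loss. A minimal evaluation sketch follows; the function name `evaluate` is hypothetical, and the toy check at the end uses `nn.Identity()` on synthetic data instead of the trained ViT so that it runs without downloading CIFAR-10.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()  # no gradients are needed for evaluation
def evaluate(model, loader, device="cpu"):
    """Return (mean cross-entropy loss, top-1 accuracy) of `model` over `loader`."""
    model.eval()  # switch layers such as dropout to inference behavior
    criterion = nn.CrossEntropyLoss(reduction="sum")  # summed, then divided by sample count
    total_loss, correct, seen = 0.0, 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        logits = model(inputs)
        total_loss += criterion(logits, labels).item()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, correct / seen

# Toy check: nn.Identity() returns its input as logits, so sample i predicts class i
x = torch.eye(4)
y = torch.tensor([0, 1, 2, 0])  # the last sample is deliberately labeled wrong
loss, acc = evaluate(nn.Identity(), DataLoader(TensorDataset(x, y), batch_size=2))
print(acc)  # 0.75
```

For the tutorial's model, one would build a `test_loader` like `train_loader` but from `datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)`, then call `evaluate(model, test_loader, device)`.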

**Output:**

Files already downloaded and verified
Epoch [1/5], Loss: 2.761860250130115
Epoch [2/5], Loss: 2.3324048172870815
Epoch [3/5], Loss: 2.324295696965106
Epoch [4/5], Loss: 2.3209078250904533
Epoch [5/5], Loss: 6.058996846106902

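After 5 epochs on CIFAR-10, the average training loss drops from about 2.76 in the first epoch to about 2.32 by the fourth, then jumps to about 6.06 in the fifth. A loss around 2.30 is close to ln 10 ≈ 2.303, the cross-entropy of a model that guesses uniformly over 10 classes, so this run has not yet learned much beyond chance, and the final spike suggests unstable optimization. Training a ViT from scratch on a small dataset like CIFAR-10 typically requires more epochs, a smaller learning rate (often with warmup), and data augmentation.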
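To verify the chance-level figure, the cross-entropy of uniform predictions over 10 classes can be computed directly with `torch.nn.functional.cross_entropy`; it equals -ln(1/10) = ln 10 regardless of which labels are used.

```python
import math
import torch
import torch.nn.functional as F

num_classes = 10
# Identical logits for every class give a uniform prediction, i.e. pure guessing
uniform_logits = torch.zeros(4, num_classes)
labels = torch.tensor([0, 3, 7, 9])  # any labels give the same loss
chance_loss = F.cross_entropy(uniform_logits, labels).item()

print(round(math.log(num_classes), 4))  # 2.3026
print(round(chance_loss, 4))            # 2.3026
```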

Conclusion

In conclusion, building a Vision Transformer (ViT) from scratch in PyTorch involves understanding the key components of the transformer architecture, such as patch embedding, self-attention, and positional encoding, and applying them to vision tasks. Training the model on a dataset like CIFAR-10 puts these pieces together end to end. While the implementation may seem complex, ViTs are a strong alternative to traditional CNNs, particularly for tasks that benefit from capturing long-range dependencies within an image, although they generally need large-scale pre-training, or longer training with augmentation and careful tuning, to match CNNs on small datasets. Fine-tuning a pre-trained ViT and tuning the optimization setup are the usual next steps for improving performance.