Implementing an Autoencoder in PyTorch

Last Updated : 11 Mar, 2025

Autoencoders are neural networks that learn to compress and reconstruct data. In this guide we’ll walk you through building a simple autoencoder in PyTorch using the MNIST dataset. This approach is useful for image compression, denoising and feature extraction.

**Implementation of Autoencoder in PyTorch**

**Step 1: Importing Modules and Loading the Dataset**

We will use the torch.nn and torch.optim modules from the torch package, and datasets and transforms from the torchvision package. After the imports, the MNIST training set is wrapped in a DataLoader named loader, which serves shuffled batches of 32 images. MNIST is a collection of 70,000 grayscale images of handwritten digits (0-9), each 28×28 pixels; the training split loaded here (train=True) holds 60,000 of them. It is widely used in deep learning for tasks like classification and feature extraction.

```python
import torch
from torch import nn, optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Convert PIL images to float tensors scaled to [0, 1]
tensor_transform = transforms.ToTensor()

# Download (if needed) and load the MNIST training split
dataset = datasets.MNIST(root="./data", train=True, download=True, transform=tensor_transform)

# Serve shuffled mini-batches of 32 images
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=32, shuffle=True)
```
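Each batch from the loader has shape (32, 1, 28, 28), while the linear layers used later expect flat vectors of 784 values. The sketch below shows that flattening step on a random stand-in batch, so it runs without downloading MNIST; the name fake_batch and its random contents are assumptions for illustration only.

```python
import torch

# Stand-in for one loader batch: 32 grayscale images, 1 channel, 28x28 pixels.
# Random values in [0, 1) mimic the range that ToTensor() produces.
fake_batch = torch.rand(32, 1, 28, 28)

# Flatten every image into a 784-element vector, one row per image.
flat = fake_batch.view(-1, 28 * 28)

print(fake_batch.shape)  # torch.Size([32, 1, 28, 28])
print(flat.shape)        # torch.Size([32, 784])
```

The same view(-1, 28 * 28) call appears in the training loop, where it is applied to real batches.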

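**Step 2: Create the Autoencoder Class**

In this step we define the autoencoder. It consists of two key components:

* **Encoder:** compresses the 784-pixel input into a 9-value latent representation through linear layers with ReLU activations: 28*28 = 784 ==> 128 ==> 64 ==> 36 ==> 18 ==> 9
* **Decoder:** expands the latent representation back to 784 pixels, ending in a Sigmoid so that outputs lie in [0, 1]: 9 ==> 18 ==> 36 ==> 64 ==> 128 ==> 784 ==> 28*28 = 784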
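To get a feel for the size of this network, the layer widths (784, 128, 64, 36, 18, 9 and the reverse) can be turned into a parameter count with plain arithmetic. This is a rough sketch rather than part of the model; the helper name linear_params is hypothetical.

```python
# Layer widths of the fully connected encoder; the decoder mirrors them.
encoder_sizes = [28 * 28, 128, 64, 36, 18, 9]
decoder_sizes = list(reversed(encoder_sizes))

def linear_params(sizes):
    """Count weights plus biases for a chain of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))

encoder_params = linear_params(encoder_sizes)
decoder_params = linear_params(decoder_sizes)
compression_ratio = encoder_sizes[0] / encoder_sizes[-1]

print(encoder_params)                   # 111913
print(decoder_params)                   # 112688
print(encoder_params + decoder_params)  # 224601
print(f"{compression_ratio:.1f}x fewer values in the latent code")  # 87.1x
```

Most of the parameters sit in the first encoder layer and the last decoder layer, which connect to the 784-pixel image.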

```python
class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        # Encoder: 784 -> 128 -> 64 -> 36 -> 18 -> 9
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 36),
            nn.ReLU(),
            nn.Linear(36, 18),
            nn.ReLU(),
            nn.Linear(18, 9)
        )
        # Decoder: 9 -> 18 -> 36 -> 64 -> 128 -> 784, Sigmoid keeps outputs in [0, 1]
        self.decoder = nn.Sequential(
            nn.Linear(9, 18),
            nn.ReLU(),
            nn.Linear(18, 36),
            nn.ReLU(),
            nn.Linear(36, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
```

**Step 3: Initializing the Model**

Now we create the model, measure reconstruction error with the Mean Squared Error loss function, and use an Adam optimizer with a learning rate of 0.001 and a weight decay of [Tex]10^{-8}[/Tex].

```python
model = AE()
loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-8)
```
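With its default settings, nn.MSELoss averages the squared difference over every element of the batch. The small check below compares it with a hand-written version on two made-up 4-pixel "images"; the tensor values are hypothetical and chosen only to keep the arithmetic easy to follow.

```python
import torch
from torch import nn

# Two tiny hypothetical "images" of 4 pixels each, with values in [0, 1].
original = torch.tensor([[0.0, 0.5, 1.0, 0.25],
                         [1.0, 0.0, 0.5, 0.75]])
reconstructed = torch.tensor([[0.1, 0.5, 0.8, 0.25],
                              [0.9, 0.2, 0.5, 0.75]])

# Default reduction="mean": average of squared errors over all 8 elements.
builtin = nn.MSELoss()(reconstructed, original)
manual = ((reconstructed - original) ** 2).mean()

print(builtin.item(), manual.item())  # both approximately 0.0125
```

The squared errors here sum to 0.10 over 8 elements, which gives 0.0125.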

**Step 4: Train the Model and Plot the Training Loss**

In this step the model is trained for 20 epochs using the **Mean Squared Error (MSE)** loss function and the **Adam** optimizer. Each batch is passed through the network, the loss is backpropagated, and the optimizer updates the model’s weights. The loss is recorded at every iteration, and after training a loss plot is generated to assess the model’s performance over time.

**Note:** This snippet takes 15 to 20 minutes to execute depending on the processor type. Set epochs = 1 for quick results. Use a GPU/TPU runtime for faster computation.

```python
epochs = 20
outputs = []
losses = []

# Train on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(epochs):
    for images, _ in loader:
        # Flatten each 28x28 image into a 784-element vector
        images = images.view(-1, 28 * 28).to(device)

        reconstructed = model(images)
        loss = loss_function(reconstructed, images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    # Keep the last batch of each epoch and report its loss
    outputs.append((epoch, images, reconstructed))
    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.6f}")

plt.style.use('fivethirtyeight')
plt.figure(figsize=(8, 5))
plt.plot(losses, label='Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.show()
```

**Output:**

The loss curve in the image shows how the model’s error decreases over training iterations. Initially the loss is high but quickly drops indicating that the model is learning.
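Because one loss value is recorded per batch, the raw curve tends to be noisy. One way to read the trend more easily is to average the values per epoch. The sketch below assumes a fixed number of batches per epoch (1,875 for 60,000 training images at batch size 32); the helper name epoch_means and the sample numbers are hypothetical.

```python
def epoch_means(losses, batches_per_epoch):
    """Average per-batch losses into one value per epoch."""
    means = []
    for start in range(0, len(losses), batches_per_epoch):
        chunk = losses[start:start + batches_per_epoch]
        means.append(sum(chunk) / len(chunk))
    return means

# Hypothetical per-batch losses: 3 epochs of 4 batches each (illustrative only).
example_losses = [0.20, 0.12, 0.10, 0.08,
                  0.07, 0.06, 0.06, 0.05,
                  0.05, 0.04, 0.05, 0.04]

means = epoch_means(example_losses, batches_per_epoch=4)
print([round(m, 3) for m in means])  # [0.125, 0.06, 0.045]
```

With the training loop from this step, the call would be epoch_means(losses, len(loader)), since len(loader) gives the number of batches per epoch.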

**Step 5: Visualize Input and Reconstructed Images**

This step focuses on evaluating the performance of the trained autoencoder. By comparing the original MNIST images with their reconstructed versions we can assess how well the model has learned to encode and decode the data.
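A visual comparison can be paired with a number per image. The sketch below computes one reconstruction error per image using nn.MSELoss with reduction="none". It runs on random stand-in tensors, so the names originals and reconstructions and their contents are assumptions; in this tutorial they would be a flattened batch and the model's output for it.

```python
import torch
from torch import nn

# Random stand-ins for 10 flattened images and slightly noisy "reconstructions".
torch.manual_seed(0)
originals = torch.rand(10, 28 * 28)
reconstructions = (originals + 0.1 * torch.randn(10, 28 * 28)).clamp(0, 1)

# reduction="none" keeps the squared error of every pixel,
# so averaging over dim=1 yields one error value per image.
per_pixel = nn.MSELoss(reduction="none")(reconstructions, originals)
per_image = per_pixel.mean(dim=1)

# Index of the image with the largest reconstruction error.
worst = per_image.argmax().item()
print(per_image.shape, worst)
```

Sorting images by this value is a simple way to find the digits the model reconstructs worst.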

```python
model.eval()
dataiter = iter(loader)
images, _ = next(dataiter)

images = images.view(-1, 28 * 28).to(device)
# Gradients are not needed for inference
with torch.no_grad():
    reconstructed = model(images)

# Top row: originals; bottom row: reconstructions
fig, axes = plt.subplots(nrows=2, ncols=10, figsize=(10, 3))
for i in range(10):
    axes[0, i].imshow(images[i].cpu().detach().numpy().reshape(28, 28), cmap='gray')
    axes[0, i].axis('off')
    axes[1, i].imshow(reconstructed[i].cpu().detach().numpy().reshape(28, 28), cmap='gray')
    axes[1, i].axis('off')
plt.show()
```

**Output:**

The above image compares original MNIST digits (top row) with their reconstructed versions (bottom row) generated by the autoencoder. While some reconstructions closely resemble the original images, others appear blurry or distorted, indicating information loss during compression. This can be improved by using a deeper network, switching to a convolutional autoencoder, or tuning hyperparameters such as the learning rate and the size of the latent space.
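As one possible starting point for the convolutional variant, here is a minimal sketch. The class name ConvAE, the channel counts, and the kernel settings are all assumptions chosen so that the shapes work out for 28×28 inputs; it has not been tuned or trained on MNIST here.

```python
import torch
from torch import nn

class ConvAE(nn.Module):
    """Hypothetical convolutional autoencoder for 1x28x28 images."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),   # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(8, 4, kernel_size=3, stride=2, padding=1),   # 14x14 -> 7x7
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4, 8, kernel_size=3, stride=2,
                               padding=1, output_padding=1),       # 7x7 -> 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(8, 1, kernel_size=3, stride=2,
                               padding=1, output_padding=1),       # 14x14 -> 28x28
            nn.Sigmoid(),  # keep outputs in [0, 1], like the linear model
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

conv_model = ConvAE()
batch = torch.rand(4, 1, 28, 28)   # random stand-in for a batch of images
latent = conv_model.encoder(batch)
out = conv_model(batch)

print(latent.shape)  # torch.Size([4, 4, 7, 7])
print(out.shape)     # torch.Size([4, 1, 28, 28])
```

Unlike the linear model, this variant takes images in their original (N, 1, 28, 28) shape, so the view(-1, 28 * 28) flattening step would be dropped. Its latent code holds 4 × 7 × 7 = 196 values per image, which is larger than the 9 used above, so the two models are not directly comparable in compression.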