Generative Adversarial Networks (GANs) in PyTorch

Last Updated : 9 Oct, 2025

Generative Adversarial Networks (GANs) let models generate realistic data such as images. A GAN trains two neural networks, the generator and the discriminator, together in a competitive setup: the generator creates synthetic images and the discriminator learns to distinguish them from real ones. This adversarial training improves both networks over time, which results in higher-quality generated images. In this article we will implement a GAN in PyTorch and train it on the MNIST dataset to generate handwritten digit images.
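The competitive setup described above is commonly written as a minimax game. As a reference point (this is the standard GAN objective, not a formula stated in this article), let D be the discriminator, G the generator, x a real image and z a noise vector:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

Here D(.) is the probability the discriminator assigns to "real". The training loop later in this article uses the common non-saturating variant for the generator: it labels generated images as real and minimizes the binary cross-entropy, which amounts to maximizing log D(G(z)).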

Why Choose PyTorch for GANs?

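PyTorch is a widely used deep learning framework. Its eager, define-by-run execution and automatic differentiation make it straightforward to write the alternating generator and discriminator updates that GAN training requires.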

Implementation

We will build and train a Generative Adversarial Network (GAN) using PyTorch to generate realistic handwritten digit images from the MNIST dataset. Below are the key steps involved:

Step 1: Importing Necessary Libraries

We will use core PyTorch modules such as torch, torch.nn and torch.optim to build and train the models. torchvision loads the MNIST dataset, and torchvision.transforms preprocesses the images before they are fed to the GAN. matplotlib is used to plot the generated images.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
```

Step 2: Define the Generator Class

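The generator class inherits from nn.Module. It takes a noise vector of size noise_dim as input and upsamples it, through a linear layer and three transposed convolutions, into a 1 x 28 x 28 image. The final Tanh keeps pixel values in [-1, 1].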

```python
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.main = nn.Sequential(
            # Project the noise vector to a 256 x 7 x 7 feature map
            nn.Linear(noise_dim, 7 * 7 * 256),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 7, 7)),
            # 7 x 7 -> 7 x 7
            nn.ConvTranspose2d(256, 128, 5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 7 x 7 -> 14 x 14
            nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 14 x 14 -> 28 x 28; Tanh keeps pixel values in [-1, 1]
            nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)
```
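To see why the generator ends at 28 x 28, the spatial size can be traced by hand. The sketch below is plain Python rather than PyTorch; the helper name conv_transpose2d_out is made up for illustration, and it applies the output-size formula for nn.ConvTranspose2d assuming square inputs and a dilation of 1.

```python
def conv_transpose2d_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Output height/width of nn.ConvTranspose2d for a square input (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# (kernel, stride, padding, output_padding) of the three ConvTranspose2d layers
layers = [(5, 1, 2, 0), (5, 2, 2, 1), (5, 2, 2, 1)]

size = 7  # spatial size after nn.Unflatten(1, (256, 7, 7))
sizes = [size]
for kernel, stride, padding, output_padding in layers:
    size = conv_transpose2d_out(size, kernel, stride, padding, output_padding)
    sizes.append(size)

print(sizes)  # [7, 7, 14, 28]
```

The first transposed convolution keeps the 7 x 7 size and only changes the channel count. The two stride-2 layers each double the size, which is how the network reaches the MNIST resolution.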

Step 3: Define the Discriminator Class

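The discriminator class also inherits from nn.Module. It takes a 1 x 28 x 28 image and outputs a single raw score (logit) for each image.

nn.Conv2d(64, 128, 5, stride=2, padding=2): Adds a second convolutional layer that increases the channels from 64 to 128, with the same kernel size, stride and padding as the first.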

```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 28 x 28 -> 14 x 14
            nn.Conv2d(1, 64, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            # 14 x 14 -> 7 x 7
            nn.Conv2d(64, 128, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.Flatten(),
            # One raw logit per image; the sigmoid is applied inside the loss
            nn.Linear(7 * 7 * 128, 1)
        )

    def forward(self, x):
        return self.main(x)
```
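A quick way to confirm the 7 * 7 * 128 input size of the final linear layer is to push a dummy batch through the layers and record the shapes. This is only a sanity-check sketch: the layer stack is rebuilt here so the snippet runs on its own, and the batch of 4 random images is an arbitrary stand-in for real data.

```python
import torch
import torch.nn as nn

# Same layer stack as the Discriminator above, rebuilt so the snippet is self-contained
disc_layers = nn.Sequential(
    nn.Conv2d(1, 64, 5, stride=2, padding=2),
    nn.LeakyReLU(0.2, inplace=True),
    nn.BatchNorm2d(64),
    nn.Conv2d(64, 128, 5, stride=2, padding=2),
    nn.LeakyReLU(0.2, inplace=True),
    nn.BatchNorm2d(128),
    nn.Flatten(),
    nn.Linear(7 * 7 * 128, 1),
)

dummy_batch = torch.randn(4, 1, 28, 28)  # 4 random MNIST-sized images

# Apply the layers one by one and remember the shape after each
x = dummy_batch
shapes = []
for layer in disc_layers:
    x = layer(x)
    shapes.append(tuple(x.shape))

print(shapes[0])   # (4, 64, 14, 14) after the first convolution
print(shapes[3])   # (4, 128, 7, 7) after the second convolution
print(shapes[-1])  # (4, 1): one logit per image
```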

Step 4: Combining the Generator and Discriminator

Here we create instances of the generator and discriminator. The generator creates fake images from random noise while the discriminator classifies images as real or fake.

```python
NOISE_DIM = 100

generator = Generator(NOISE_DIM)
discriminator = Discriminator()
```
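The discriminator defined above returns raw logits, not probabilities. Below is a minimal sketch of turning such scores into a real/fake decision. The logit values are made up for illustration, and the 0.5 threshold is an assumption rather than something the training code uses.

```python
import torch

# Hypothetical discriminator outputs for three images (raw logits, shape (N, 1))
logits = torch.tensor([[2.0], [-1.0], [0.0]])

probs = torch.sigmoid(logits)   # probability that each image is real
is_real = probs > 0.5           # assumed decision threshold

print(is_real.squeeze(1).tolist())  # [True, False, False]
```

A logit of 0 corresponds to a probability of exactly 0.5, meaning the discriminator is undecided.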

Step 5: Device Configuration

We select a GPU if one is available and fall back to the CPU otherwise, then move both models to that device so that training runs on the fastest available hardware.

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)
```
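Once the models live on a device, every tensor fed to them has to be on the same device, which is why the later steps pass device=device when creating noise and labels. A small sketch of the pattern, with an nn.Linear layer standing in for a model and arbitrary sizes:

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the noise batch directly on the target device instead of moving it afterwards
noise = torch.randn(8, 100, device=device)

layer = torch.nn.Linear(100, 10).to(device)  # stand-in for a model
out = layer(noise)                           # works because both live on `device`

print(noise.device.type == device.type)  # True
```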

Step 6: Set Loss Function, Optimizer and Hyperparameters

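We use nn.BCEWithLogitsLoss, which applies the sigmoid internally and therefore expects the raw logits that the discriminator outputs. Both networks are optimized with Adam, using a learning rate of 0.0002 and betas of (0.5, 0.999).

```python
# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss)
criterion = nn.BCEWithLogitsLoss()

generator_optimizer = optim.Adam(
    generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(
    discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

NUM_EPOCHS = 5
BATCH_SIZE = 256
```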
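To make the loss values printed during training easier to read, here is a plain-Python sketch of binary cross-entropy on a single logit. It uses one standard numerically stable form of the formula. The helper name bce_with_logits is made up for illustration, and this is not claimed to be PyTorch's internal implementation.

```python
import math

def bce_with_logits(logit, target):
    """Binary cross-entropy on a raw logit, in the numerically stable form
    max(x, 0) - x * y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

# An undecided discriminator (logit 0, i.e. probability 0.5) pays log(2) either way
print(round(bce_with_logits(0.0, 1.0), 4))  # 0.6931
print(round(bce_with_logits(0.0, 0.0), 4))  # 0.6931

# A confident, correct prediction costs little; a confident, wrong one costs a lot
print(round(bce_with_logits(4.0, 1.0), 4))  # 0.0181
print(round(bce_with_logits(4.0, 0.0), 4))  # 4.0181
```

A per-term loss near 0.69 therefore means the discriminator is close to guessing.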

Step 7: Prepare DataLoader

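ToTensor scales pixel values to [0, 1], and Normalize((0.5,), (0.5,)) then maps them to [-1, 1], the same range as the generator's Tanh output.

```python
transform = transforms.Compose([
    transforms.ToTensor(),                 # pixel values in [0, 1]
    transforms.Normalize((0.5,), (0.5,))   # rescale to [-1, 1]
])
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True)
```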
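The arithmetic behind this step can be checked without loading any data. The sketch below uses plain Python; the helper names normalize and denormalize are made up for illustration. The batch count assumes the 60,000-image MNIST training split and the DataLoader default of keeping the last, smaller batch.

```python
import math

def normalize(pixel, mean=0.5, std=0.5):
    """What transforms.Normalize((0.5,), (0.5,)) does to a single pixel value."""
    return (pixel - mean) / std

def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, useful before displaying generated images."""
    return value * std + mean

print([normalize(p) for p in (0.0, 0.5, 1.0)])  # [-1.0, 0.0, 1.0]
print(denormalize(-1.0), denormalize(1.0))       # 0.0 1.0

# 60,000 training images in batches of 256; the last batch is smaller
batches_per_epoch = math.ceil(60000 / 256)
print(batches_per_epoch)  # 235
print(60000 % 256)        # 96 images in the final batch
```

Because the final batch holds fewer than 256 images, the training loop sizes its labels and noise with real_images.size(0) rather than BATCH_SIZE.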

Step 8: Training Process

This training loop iterates over the specified number of epochs, training the GAN by alternating between updating the discriminator and the generator:

  1. real_labels = torch.ones(real_images.size(0), 1, device=device): Creates labels of ones representing real images for the discriminator.
  2. noise = torch.randn(real_images.size(0), NOISE_DIM, device=device): Generates random noise vectors to feed the generator.
  3. fake_labels = torch.zeros(real_images.size(0), 1, device=device): Creates labels of zeros for the generated images in the discriminator update. Passing fake_images.detach() keeps this update from backpropagating into the generator.
  4. fake_labels = torch.ones(real_images.size(0), 1, device=device): In the generator update, creates labels of ones so that the generator is rewarded when the discriminator classifies generated images as real.

```python
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_loader):
        real_images, _ = data
        real_images = real_images.to(device)

        # Discriminator update: real images should be classified as 1
        discriminator_optimizer.zero_grad()
        real_labels = torch.ones(real_images.size(0), 1, device=device)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)
        real_loss.backward()

        # Generated images should be classified as 0; detach() keeps
        # gradients from flowing back into the generator here
        noise = torch.randn(real_images.size(0), NOISE_DIM, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)
        fake_loss.backward()
        discriminator_optimizer.step()

        # Generator update: label the fakes as 1 so the generator is
        # rewarded when the discriminator is fooled
        generator_optimizer.zero_grad()
        fake_labels = torch.ones(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images)
        gen_loss = criterion(fake_outputs, fake_labels)
        gen_loss.backward()
        generator_optimizer.step()

        if i % 100 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], '
                  f'Discriminator Loss: {real_loss.item() + fake_loss.item():.4f}, '
                  f'Generator Loss: {gen_loss.item():.4f}')
```
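The role of detach() in the discriminator update can be checked on a toy example. In the sketch below, two small nn.Linear layers stand in for the real networks (their names and sizes are illustrative only). It shows that gradients reach the generator only when the generated batch is passed without detach().

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-ins for the real networks, just to trace where gradients flow
toy_generator = nn.Linear(4, 3)
toy_discriminator = nn.Linear(3, 1)

fake = toy_generator(torch.randn(2, 4))

# Discriminator step: detach() cuts the graph, so the generator gets no gradient
toy_discriminator(fake.detach()).sum().backward()
grad_after_d_step = toy_generator.weight.grad   # still None

# Generator step: no detach(), so gradients reach the generator's weights
toy_discriminator(fake).sum().backward()
grad_after_g_step = toy_generator.weight.grad   # a (3, 4) tensor

print(grad_after_d_step is None, grad_after_g_step is not None)  # True True
```

This is also why the training loop can reuse the same fake_images for both updates: the generator's graph is left untouched by the discriminator step.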

Step 9: Visualization

Generate and save sample images from the trained generator at intervals to track progress visually.

  1. test_noise = torch.randn(16, NOISE_DIM, device=device): Creates a batch of 16 random noise vectors for image generation.

```python
def generate_and_save_images(model, epoch, noise):
    model.eval()  # use BatchNorm running statistics for inference
    with torch.no_grad():
        fake_images = model(noise).cpu()
        fake_images = fake_images.view(fake_images.size(0), 28, 28)

    # Plot the 16 generated digits in a 4 x 4 grid
    fig = plt.figure(figsize=(4, 4))
    for i in range(fake_images.size(0)):
        plt.subplot(4, 4, i + 1)
        plt.imshow(fake_images[i], cmap='gray')
        plt.axis('off')

    plt.savefig(f'image_at_epoch_{epoch+1:04d}.png')
    plt.show()

test_noise = torch.randn(16, NOISE_DIM, device=device)
# Pass the zero-based index of the final epoch so the file name matches NUM_EPOCHS
generate_and_save_images(generator, NUM_EPOCHS - 1, test_noise)
```

Output:

[Image: gan-training]

The generated digits are not yet sharp because the model was trained for only 5 epochs. Training for more epochs and tuning the hyperparameters to your needs should give better results.

[Image: Output Image]

As GANs continue to improve, their ability to create increasingly realistic and diverse images will open up new possibilities across art, design and data science.