Understanding KL Divergence in PyTorch

Last Updated : 8 Nov, 2025

Kullback-Leibler (KL) divergence is a fundamental concept in information theory and statistics, used to measure the difference between two probability distributions. In machine learning, it is often used to compare the probability distribution predicted by a model with the true distribution of the data. PyTorch, a popular deep learning library, provides several ways to compute KL divergence, each suited to a different kind of input.

What is KL Divergence?

KL divergence quantifies how much one probability distribution diverges from a second, expected probability distribution. Mathematically, it is defined as:

D_{KL}(P \parallel Q) = \sum_{x \in X} P(x) \log \frac{P(x)}{Q(x)}

Here P and Q are two probability distributions over the same variable x, and the sum is finite only if Q(x) > 0 wherever P(x) > 0. KL divergence is not symmetric: in general, D_{KL}(P \parallel Q) \neq D_{KL}(Q \parallel P)
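The definition and the asymmetry can be checked directly on a small example. The following is a minimal sketch, not a library routine: the helper name `kl_discrete` and the two example distributions are chosen here for illustration, and the helper assumes every entry of both distributions is strictly positive.

```python
import torch

def kl_discrete(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """D_KL(p || q) for discrete distributions with strictly positive entries."""
    return torch.sum(p * torch.log(p / q))

# Two illustrative distributions over the same three outcomes
P = torch.tensor([0.7, 0.2, 0.1])
Q = torch.tensor([0.5, 0.3, 0.2])

kl_pq = kl_discrete(P, Q)  # D_KL(P || Q)
kl_qp = kl_discrete(Q, P)  # D_KL(Q || P)

# The two directions give different values
print(kl_pq.item(), kl_qp.item())
```

Swapping the arguments changes the result, which is why the order of P and Q has to be stated whenever a KL value is reported.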

Why Use KL Divergence?

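KL divergence is widely used because it summarizes, in a single non-negative number, how far a model's predicted distribution is from a target distribution, and that number can be minimized directly as a training loss.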

Implementing KL Divergence in PyTorch

PyTorch offers multiple methods to compute KL divergence, each suited for different scenarios. Below, we explore these methods and their applications.

1. Using torch.nn.functional.kl_div

The torch.nn.functional.kl_div function is a low-level method in PyTorch that computes the KL divergence between two tensors. By default, it expects the input tensor to contain log-probabilities (the model's distribution, Q in the formula above) and the target tensor to contain probabilities (the reference distribution, P).

```python
import torch
import torch.nn.functional as F

# Define input and target tensors.
# log_softmax treats the input values as unnormalized logits and returns log-probabilities.
input = F.log_softmax(torch.tensor([[0.8, 0.15, 0.05]]), dim=1)
target = torch.tensor([[0.7, 0.2, 0.1]])

# Compute KL divergence
kl_divergence = F.kl_div(input, target, reduction='batchmean')
print(kl_divergence)
```

**Output:**

tensor(0.0935)

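This function supports the reduction modes 'none', 'sum', 'mean', and 'batchmean'. Use 'batchmean' to match the mathematical definition of KL divergence: it sums over all elements and divides by the batch size, whereas 'mean' divides by the total number of elements and so does not return the true KL divergence.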
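The relationship between the reduction modes is easiest to see on a batch with more than one row. The sketch below uses a made-up batch of two rows; the variable names are illustrative only. Note that PyTorch emits a warning when 'mean' is used with this function.

```python
import torch
import torch.nn.functional as F

# Illustrative batch: 2 rows, 3 classes each
logits = torch.tensor([[0.8, 0.15, 0.05],
                       [0.2, 0.5, 0.3]])
log_q = F.log_softmax(logits, dim=1)   # input: log-probabilities
p = torch.tensor([[0.7, 0.2, 0.1],
                  [0.1, 0.6, 0.3]])    # target: probabilities

per_element = F.kl_div(log_q, p, reduction='none')    # one term per element, shape (2, 3)
total = F.kl_div(log_q, p, reduction='sum')           # sum of all terms
batchmean = F.kl_div(log_q, p, reduction='batchmean') # total / batch size
mean = F.kl_div(log_q, p, reduction='mean')           # total / number of elements

print(per_element.shape, total.item(), batchmean.item(), mean.item())
```

Here 'batchmean' equals the total divided by 2 (the number of rows), while 'mean' equals the total divided by 6 (the number of elements).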

2. Using torch.nn.KLDivLoss

The torch.nn.KLDivLoss class is the module form of torch.nn.functional.kl_div. It computes the same quantity, but is configured once and then called like any other loss function when training neural networks.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define input and target tensors
input = F.log_softmax(torch.tensor([[0.8, 0.15, 0.05]]), dim=1)
target = torch.tensor([[0.7, 0.2, 0.1]])

# Initialize KLDivLoss
criterion = nn.KLDivLoss(reduction='batchmean')

# Compute loss
loss = criterion(input, target)
print(loss)
```

**Output:**

tensor(0.0935)

This loss function is particularly useful in scenarios where you need to compare the output distribution of a model with a target distribution during training.
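When the target distribution is itself available as log-probabilities (for example, the log_softmax output of another model), KLDivLoss can take it directly through its `log_target` argument. The following is a small sketch under that assumption; the two logit tensors and their names are invented for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical logits for the model output and for the target distribution
model_logits = torch.tensor([[1.0, 0.5, -0.5]])
target_logits = torch.tensor([[1.2, 0.3, -0.8]])

log_q = F.log_softmax(model_logits, dim=1)   # input: log-probabilities
log_p = F.log_softmax(target_logits, dim=1)  # target: log-probabilities

# Default: the target is given as probabilities
loss_prob_target = nn.KLDivLoss(reduction='batchmean')(log_q, log_p.exp())

# log_target=True: the target is given as log-probabilities
loss_log_target = nn.KLDivLoss(reduction='batchmean', log_target=True)(log_q, log_p)

print(loss_prob_target.item(), loss_log_target.item())
```

Both calls compute the same divergence. Passing log-probabilities avoids an exp followed by a log, which can help numerically when some target probabilities are very small.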

3. Using torch.distributions.kl.kl_divergence

For more complex probability distributions, PyTorch provides torch.distributions.kl.kl_divergence, which can compute KL divergence between two distribution objects. This method is particularly useful when dealing with distributions beyond simple tensors, such as Gaussian distributions.

```python
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

# Define two Gaussian distributions (mean, standard deviation)
p = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
q = Normal(torch.tensor([1.0]), torch.tensor([1.5]))

# Compute KL divergence D_KL(p || q)
kl_div = kl_divergence(p, q)
print(kl_div)
```

**Output:**

tensor([0.3499])

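This function works only for pairs of distribution types that have a KL implementation registered with PyTorch (via torch.distributions.kl.register_kl); for an unregistered pair it raises NotImplementedError. For registered pairs, it computes the divergence from the distributions' parameters rather than from tensors of probabilities.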
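For two univariate Gaussians, the KL divergence has a well-known closed form, \log(\sigma_q / \sigma_p) + (\sigma_p^2 + (\mu_p - \mu_q)^2) / (2 \sigma_q^2) - 1/2, which can be compared against the library result. The sketch below reuses the parameters from the example above; the variable names are chosen here for illustration.

```python
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

# Parameters of p = N(0, 1) and q = N(1, 1.5^2), as in the example above
mu_p, sigma_p = torch.tensor([0.0]), torch.tensor([1.0])
mu_q, sigma_q = torch.tensor([1.0]), torch.tensor([1.5])

# Library result for D_KL(p || q)
library_kl = kl_divergence(Normal(mu_p, sigma_p), Normal(mu_q, sigma_q))

# Closed-form expression for two univariate Gaussians
closed_form = (torch.log(sigma_q / sigma_p)
               + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2 * sigma_q ** 2)
               - 0.5)

print(library_kl, closed_form)
```

Both values agree with the 0.3499 shown in the output above.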

Practical Example: Minimizing KL Divergence in PyTorch

Let’s create a simple example where we minimize KL divergence between two probability distributions in PyTorch:

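```python
import torch
import torch.nn.functional as F
import torch.optim as optim

# P holds unnormalized parameters to optimize; Q is the fixed target distribution
P = torch.tensor([0.2, 0.5, 0.3], requires_grad=True)
Q = torch.tensor([0.1, 0.7, 0.2])

optimizer = optim.Adam([P], lr=0.01)

for _ in range(100):
    optimizer.zero_grad()

    # Normalize P into a valid probability distribution
    P_soft = torch.softmax(P, dim=0)

    # With input log(Q) and target P_soft, kl_div returns D_KL(P_soft || Q)
    kl_loss = F.kl_div(Q.log(), P_soft, reduction='sum')
    kl_loss.backward()

    optimizer.step()

    print(f"KL Loss: {kl_loss.item()}")
```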

**Output:**

KL Loss: 0.22997523844242096
KL Loss: 0.22365564107894897
KL Loss: 0.21740710735321045
KL Loss: 0.2112329602241516
KL Loss: 0.20513615012168884
KL Loss: 0.19912001490592957
.
.
KL Loss: 0.00514531135559082
KL Loss: 0.004947632551193237
KL Loss: 0.004757806658744812

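In this example, we use the Adam optimizer to minimize the KL divergence between two distributions. The tensor P holds unnormalized parameters, and softmax turns them into a valid distribution at every step. Because F.kl_div takes the log of Q as its input and the softmax of P as its target, the quantity being minimized is D_{KL}(softmax(P) \parallel Q). Gradient descent on P therefore moves softmax(P) closer to Q, and the printed loss shrinks toward zero.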
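The argument order of F.kl_div is easy to get backwards, so it helps to confirm which direction of the divergence a given call computes. The following is a small check, using the starting values of the example above; the variable names are illustrative.

```python
import torch
import torch.nn.functional as F

# Starting values from the example above
logits = torch.tensor([0.2, 0.5, 0.3])
Q = torch.tensor([0.1, 0.7, 0.2])
P_soft = torch.softmax(logits, dim=0)

# Call used in the training loop: input = log(Q), target = P_soft
library_value = F.kl_div(Q.log(), P_soft, reduction='sum')

# Manual D_KL(P_soft || Q) and the reverse direction D_KL(Q || P_soft)
manual_value = torch.sum(P_soft * torch.log(P_soft / Q))
reverse_value = torch.sum(Q * torch.log(Q / P_soft))

print(library_value.item(), manual_value.item(), reverse_value.item())
```

The library value matches the manual D_KL(P_soft || Q), not the reverse direction: the target argument plays the role of P in the definition, and the input argument supplies log Q.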

Applications of KL Divergence

KL divergence is widely used in machine learning for various purposes, including model training, distribution comparison, and probabilistic inference.

Challenges and Considerations

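While KL divergence is a powerful tool, it comes with certain challenges. It is not symmetric, so the order of the two distributions matters, and it becomes infinite when Q assigns zero probability to an outcome that P considers possible.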
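One difficulty that is easy to reproduce is the zero-probability case: if Q assigns zero probability to an outcome that P allows, the divergence is infinite. The sketch below shows this and one common workaround, adding a small constant and renormalizing. The helper `kl_discrete`, the example distributions, and the value of `eps` are all assumptions made for illustration, and the smoothed value depends on the `eps` chosen.

```python
import torch

def kl_discrete(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """D_KL(p || q) for discrete distributions; assumes p has no zero entries."""
    return torch.sum(p * torch.log(p / q))

p = torch.tensor([0.5, 0.5])
q = torch.tensor([1.0, 0.0])   # q gives zero probability to the second outcome

raw = kl_discrete(p, q)        # infinite: p allows an outcome that q rules out

# Workaround: add a small constant to q and renormalize (eps is an arbitrary choice)
eps = 1e-6
q_smooth = (q + eps) / (q + eps).sum()
smoothed = kl_discrete(p, q_smooth)

print(raw.item(), smoothed.item())
```

Smoothing makes the loss finite and usable for optimization, at the cost of no longer measuring the divergence from the original Q exactly.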

Conclusion

KL divergence is an essential concept in machine learning, providing a measure of how one probability distribution diverges from another. PyTorch offers robust tools for computing KL divergence, making it accessible for various applications in deep learning and beyond. By understanding the different methods available in PyTorch and their appropriate use cases, practitioners can effectively leverage KL divergence in their models. Whether used for model training, distribution comparison, or probabilistic inference, KL divergence remains a cornerstone of modern machine learning techniques.