KLDivLoss — PyTorch 2.7 documentation

class torch.nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=False)[source]

The Kullback-Leibler divergence loss.

For tensors of the same shape $y_{\text{pred}},\ y_{\text{true}}$, where $y_{\text{pred}}$ is the input and $y_{\text{true}}$ is the target, we define the pointwise KL-divergence as

$$L(y_{\text{pred}},\ y_{\text{true}}) = y_{\text{true}} \cdot \log \frac{y_{\text{true}}}{y_{\text{pred}}} = y_{\text{true}} \cdot (\log y_{\text{true}} - \log y_{\text{pred}})$$

To avoid underflow issues when computing this quantity, this loss expects the argument input in log-space. The argument target may also be provided in log-space if log_target=True.

To summarise, this function is roughly equivalent to computing

if not log_target:  # default
    loss_pointwise = target * (target.log() - input)
else:
    loss_pointwise = target.exp() * (target - input)

and then reducing this result depending on the argument reduction as

if reduction == "mean": # default loss = loss_pointwise.mean() elif reduction == "batchmean": # mathematically correct loss = loss_pointwise.sum() / input.size(0) elif reduction == "sum": loss = loss_pointwise.sum() else: # reduction == "none" loss = loss_pointwise

Note

As with all the other losses in PyTorch, this function expects the first argument, input, to be the output of the model (e.g. the neural network) and the second, target, to be the observations in the dataset. This differs from the standard mathematical notation $KL(P\ ||\ Q)$ where $P$ denotes the distribution of the observations and $Q$ denotes the model.
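Since the KL divergence is not symmetric, the argument order matters: kl_loss(input, target) corresponds to KL(target || exp(input)). A minimal sketch of that asymmetry, with tensor names and shapes chosen only for illustration:

import torch
import torch.nn as nn
import torch.nn.functional as F

kl_loss = nn.KLDivLoss(reduction="batchmean")
log_q = F.log_softmax(torch.randn(2, 4), dim=1)  # model output Q, in log-space
p = F.softmax(torch.randn(2, 4), dim=1)          # observations P

kl_pq = kl_loss(log_q, p)              # KL(P || Q)
kl_qp = kl_loss(p.log(), log_q.exp())  # KL(Q || P) -- generally a different value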

Warning

reduction="mean" doesn't return the true KL divergence value; please use reduction="batchmean", which aligns with the mathematical definition.
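Concretely, "mean" divides the summed pointwise loss by the total number of elements, while "batchmean" divides by the batch size only, so for an input of shape (N, C) the two differ by a factor of C. A quick check under those assumed shapes:

import torch
import torch.nn as nn
import torch.nn.functional as F

input = F.log_softmax(torch.randn(3, 5), dim=1)
target = F.softmax(torch.randn(3, 5), dim=1)

loss_mean = nn.KLDivLoss(reduction="mean")(input, target)            # sum / (3 * 5)
loss_batchmean = nn.KLDivLoss(reduction="batchmean")(input, target)  # sum / 3
assert torch.allclose(loss_batchmean, loss_mean * 5)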

Parameters

* size_average (bool, optional) – Deprecated (see reduction).
* reduce (bool, optional) – Deprecated (see reduction).
* reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'batchmean' | 'sum' | 'mean'. Default: 'mean'.
* log_target (bool, optional) – Specifies whether target is passed in log-space. Default: False.

Shape:

* Input: (*), where * means any number of dimensions.
* Target: (*), same shape as the input.
* Output: scalar by default. If reduction is 'none', then (*), same shape as the input.

Examples:

import torch
import torch.nn as nn
import torch.nn.functional as F

kl_loss = nn.KLDivLoss(reduction="batchmean")

# input should be a distribution in the log space

input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1)

# Sample a batch of distributions. Usually this would come from the dataset

target = F.softmax(torch.rand(3, 5), dim=1)
output = kl_loss(input, target)

kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
log_target = F.log_softmax(torch.rand(3, 5), dim=1)
output = kl_loss(input, log_target)
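As a sanity check that follows directly from the two pointwise formulas above, passing the same target in log-space with log_target=True yields the same loss as passing it in probability space with the default setting (reusing input and target from the example above):

kl_prob = nn.KLDivLoss(reduction="batchmean")(input, target)
kl_log = nn.KLDivLoss(reduction="batchmean", log_target=True)(input, target.log())
assert torch.allclose(kl_prob, kl_log)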