torch.renorm (PyTorch 2.7 documentation)

torch.renorm(input, p, dim, maxnorm, *, out=None) → Tensor

Returns a tensor where each sub-tensor of input along dimension dim is rescaled, if necessary, so that its p-norm does not exceed maxnorm. The sub-tensors are the slices input.select(dim, i); for a 2-D input with dim=0, they are the rows.
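A minimal sketch of the computation described above, written with standard tensor ops. The helper name renorm_reference is hypothetical (it is not part of PyTorch). It assumes an input with at least 2 dimensions, and it leaves a sub-tensor whose norm exactly equals maxnorm untouched, which is an assumption about the boundary case. Its result should match torch.renorm up to floating-point rounding.

```python
import torch


def renorm_reference(x: torch.Tensor, p: float, dim: int, maxnorm: float) -> torch.Tensor:
    """Hypothetical helper sketching what torch.renorm computes.

    Each slice x.select(dim, i) is one sub-tensor; its p-norm is taken
    over all remaining dimensions. Assumes x has at least 2 dimensions.
    """
    if x.dim() < 2:
        raise ValueError("this sketch assumes an input with at least 2 dimensions")
    dim = dim % x.dim()  # turn a negative dim into its positive equivalent
    other_dims = tuple(d for d in range(x.dim()) if d != dim)
    # One p-norm per sub-tensor; keepdim=True keeps a size-1 axis for each
    # reduced dimension so the norms broadcast back against x.
    norms = torch.linalg.vector_norm(x, ord=p, dim=other_dims, keepdim=True)
    # Sub-tensors whose norm exceeds maxnorm are scaled down to norm maxnorm;
    # all others get a factor of exactly 1 (equality case is an assumption).
    scale = torch.where(norms > maxnorm, maxnorm / norms, torch.ones_like(norms))
    return x * scale


# Same input as the Example on this page.
x = torch.ones(3, 3)
x[1].fill_(2)
x[2].fill_(3)
print(renorm_reference(x, 1, 0, 5))
print(torch.renorm(x, 1, 0, 5))
```

Computing all norms at once with keepdim=True avoids a Python loop over the slices: the per-slice factor broadcasts back over every element of its sub-tensor.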

Note

If the p-norm of a sub-tensor (for example, a row when dim=0 on a 2-D input) is lower than maxnorm, that sub-tensor is returned unchanged.
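To illustrate the Note together with the role of dim, here is a small sketch using dim=1, so that the sub-tensors are the columns. The input values are made up for illustration: column 0 has 2-norm 5 and should be rescaled to 2-norm 1.0, while column 1 has 2-norm 0.5 and, being below maxnorm = 1.0, should come back unchanged.

```python
import torch

# Column 0 has 2-norm sqrt(3^2 + 4^2) = 5.0; column 1 has 2-norm 0.5.
x = torch.tensor([[3.0, 0.3],
                  [4.0, 0.4]])

# dim=1: the sub-tensors are the columns x[:, 0] and x[:, 1].
y = torch.renorm(x, 2, 1, 1.0)

# Column 0 exceeded maxnorm and is rescaled to 2-norm 1.0 (about [0.6, 0.8]);
# column 1 was already below maxnorm and is left as it was.
print(y)
print(torch.linalg.vector_norm(y, ord=2, dim=0))  # per-column 2-norms
```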

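Parameters

input (Tensor) – the input tensor.
p (float) – the power for the norm computation.
dim (int) – the dimension to slice over to get the sub-tensors.
maxnorm (float) – the maximum norm to keep each sub-tensor under.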

Keyword Arguments

out (Tensor, optional) – the output tensor.
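A short sketch of the out= keyword documented above, alongside the in-place method Tensor.renorm_ (a separate PyTorch method, not described on this page). It rebuilds the tensor from the Example; both variants should produce the same values as the functional call.

```python
import torch

x = torch.ones(3, 3)
x[1].fill_(2)
x[2].fill_(3)

# Write the result into a preallocated tensor via the out= keyword.
out = torch.empty(3, 3)
torch.renorm(x, 1, 0, 5, out=out)

# Tensor.renorm_ modifies its tensor in place; clone first so x is kept.
z = x.clone()
z.renorm_(1, 0, 5)

print(torch.allclose(out, z))
```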

Example:

```python
>>> import torch
>>> x = torch.ones(3, 3)
>>> x[1].fill_(2)
tensor([2., 2., 2.])
>>> x[2].fill_(3)
tensor([3., 3., 3.])
>>> x
tensor([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.]])
>>> torch.renorm(x, 1, 0, 5)
tensor([[1.0000, 1.0000, 1.0000],
        [1.6667, 1.6667, 1.6667],
        [1.6667, 1.6667, 1.6667]])
```
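The values in the Example can be checked by hand. Writing x_i for the sub-tensor x.select(dim, i) (notation introduced here, not from the original page), the rule described above is, up to floating-point rounding, the following; treating a norm exactly equal to maxnorm as unchanged is an assumption. With p = 1, dim = 0 and maxnorm = 5, the sub-tensors are the rows:

```latex
\operatorname{renorm}(x)_i =
\begin{cases}
  \dfrac{\mathrm{maxnorm}}{\lVert x_i \rVert_p}\, x_i, & \lVert x_i \rVert_p > \mathrm{maxnorm},\\[6pt]
  x_i, & \text{otherwise,}
\end{cases}

\begin{aligned}
\lVert x_0 \rVert_1 &= 1 + 1 + 1 = 3 \le 5 &&\Rightarrow\ x_0 \text{ unchanged},\\
\lVert x_1 \rVert_1 &= 2 + 2 + 2 = 6 > 5 &&\Rightarrow\ 2 \cdot \tfrac{5}{6} = \tfrac{5}{3} \approx 1.6667,\\
\lVert x_2 \rVert_1 &= 3 + 3 + 3 = 9 > 5 &&\Rightarrow\ 3 \cdot \tfrac{5}{9} = \tfrac{5}{3} \approx 1.6667.
\end{aligned}
```

Rows 1 and 2 end up identical because each is rescaled to the same 1-norm, 5, spread over 3 equal entries.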