torch.renorm — PyTorch 2.7 documentation
torch.renorm(input, p, dim, maxnorm, *, out=None) → Tensor
Returns a tensor where each sub-tensor of input along dimension dim is normalized such that the p-norm of the sub-tensor does not exceed maxnorm.
Note
If the norm of a sub-tensor (a row, when dim is 0 on a 2-D input) is lower than maxnorm, that sub-tensor is left unchanged.
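The behavior described above can be approximated by hand. The following is a minimal sketch, not the library's implementation: `renorm_sketch` is a hypothetical helper, the input is assumed to have at least two dimensions, and the small `eps` in the denominator is an assumption added for numerical safety.

```python
import torch

def renorm_sketch(input, p, dim, maxnorm, eps=1e-7):
    """Hypothetical re-implementation of the renorm semantics, for illustration only."""
    # Normalize a possibly negative dim, then reduce over every other dimension
    # so that there is exactly one norm per sub-tensor along `dim`.
    dim = dim % input.dim()
    reduce_dims = tuple(d for d in range(input.dim()) if d != dim)
    norms = torch.linalg.vector_norm(input, ord=p, dim=reduce_dims, keepdim=True)
    # Sub-tensors whose norm exceeds maxnorm are scaled down to it;
    # all others get a factor of 1 and are left unchanged.
    scale = torch.where(norms > maxnorm,
                        maxnorm / (norms + eps),
                        torch.ones_like(norms))
    return input * scale

x = torch.tensor([[1.0, 1.0, 1.0],
                  [2.0, 2.0, 2.0],
                  [3.0, 3.0, 3.0]])
manual = renorm_sketch(x, 1, 0, 5)
builtin = torch.renorm(x, 1, 0, 5)
print(torch.allclose(manual, builtin))
```

Keeping the reduced dimensions with `keepdim=True` lets the scale factor broadcast back over each sub-tensor without any reshaping.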
Parameters
- input (Tensor) – the input tensor.
- p (float) – the power for the norm computation.
- dim (int) – the dimension to slice over to get the sub-tensors.
- maxnorm (float) – the maximum norm to keep each sub-tensor under.
Keyword Arguments
- out (Tensor, optional) – the output tensor.
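To illustrate how dim selects the sub-tensors and how out can be used, here is a short sketch. The tensor values and the names `by_row`, `by_col`, and `buf` are chosen only for this example.

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [1.0, 2.0, 3.0],
                  [1.0, 2.0, 3.0]])

# dim=0: the sub-tensors are x[0], x[1], x[2] (rows), each with 1-norm 6,
# so every row is scaled down to a 1-norm of about 5.
by_row = torch.renorm(x, 1, 0, 5)

# dim=1: the sub-tensors are x[:, 0], x[:, 1], x[:, 2] (columns), with
# 1-norms 3, 6 and 9, so only the last two columns are scaled.
by_col = torch.renorm(x, 1, 1, 5)

# The result can also be written into a preallocated tensor via `out`.
buf = torch.empty_like(x)
torch.renorm(x, 1, 1, 5, out=buf)

print(by_row.abs().sum(dim=1))  # per-row 1-norms after renorm
print(by_col.abs().sum(dim=0))  # per-column 1-norms after renorm
```

The same input gives different results for the two values of dim because the norm is computed over different slices.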
Example:
>>> import torch
>>> x = torch.ones(3, 3)
>>> x[1].fill_(2)
tensor([ 2.,  2.,  2.])
>>> x[2].fill_(3)
tensor([ 3.,  3.,  3.])
>>> x
tensor([[ 1.,  1.,  1.],
        [ 2.,  2.,  2.],
        [ 3.,  3.,  3.]])
>>> torch.renorm(x, 1, 0, 5)
tensor([[ 1.0000,  1.0000,  1.0000],
        [ 1.6667,  1.6667,  1.6667],
        [ 1.6667,  1.6667,  1.6667]])
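One way to check the example is to compare the 1-norm of each row before and after the call. This is a small sketch using `torch.linalg.vector_norm`; the names `before` and `after` are illustrative.

```python
import torch

x = torch.ones(3, 3)
x[1].fill_(2)
x[2].fill_(3)
y = torch.renorm(x, 1, 0, 5)

# 1-norm of each row (sub-tensor along dim 0) before and after renorm.
before = torch.linalg.vector_norm(x, ord=1, dim=1)
after = torch.linalg.vector_norm(y, ord=1, dim=1)

print(before)  # row norms 3, 6, 9
print(after)   # first row keeps norm 3; the other two are scaled to about 5
```

The first row already satisfies the limit and is returned as is, while the second and third rows are scaled by roughly 5/6 and 5/9, which is why both end up with entries of about 1.6667.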