torch.nn.utils.weight_norm — PyTorch 2.7 documentation
torch.nn.utils.weight_norm(module, name='weight', dim=0)
Apply weight normalization to a parameter in the given module.
\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This replaces the parameter specified by name (e.g. 'weight') with two parameters: one specifying the magnitude (e.g. 'weight_g') and one specifying the direction (e.g. 'weight_v'). Weight normalization is implemented via a hook that recomputes the weight tensor from the magnitude and direction before every forward() call.
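As a rough illustration of what the hook recomputes, the sketch below rebuilds the weight by hand from 'weight_g' and 'weight_v' for the default dim=0 and compares it with the weight attribute on the module. The names w_manual and matches are illustrative only, and the tolerance is an assumption chosen to absorb float32 rounding.

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

torch.manual_seed(0)
m = weight_norm(nn.Linear(20, 40), name="weight")

g = m.weight_g  # magnitude, one value per output row
v = m.weight_v  # direction, same shape as the original weight

# With dim=0 the norm is taken over every dimension except dim 0,
# which for a 2-D Linear weight means one norm per row.
v_norm = v.norm(dim=1, keepdim=True)
w_manual = g * v / v_norm

# Run a forward pass so the hook recomputes m.weight from g and v.
_ = m(torch.randn(3, 20))

# Illustrative check: the hand-built weight should agree with m.weight.
matches = torch.allclose(m.weight, w_manual, atol=1e-6)
print(matches)
```

Because 'weight_g' and 'weight_v' are the registered parameters, an optimizer updates those two tensors, and the weight used in forward() is derived from them on each call.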
By default, with dim=0, the norm is computed independently per output channel/plane. To compute a norm over the entire weight tensor, use dim=None.
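The difference between the two settings can be seen in the shape of the magnitude parameter. The following is a minimal sketch, assuming a small Conv2d and Linear layer picked purely for illustration; the variable names are not part of the API.

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

torch.manual_seed(0)

# dim=0: one magnitude per output channel of the convolution.
conv = weight_norm(nn.Conv2d(3, 8, kernel_size=3), dim=0)
per_channel_shape = tuple(conv.weight_g.shape)
per_channel_ok = torch.allclose(
    conv.weight_g.flatten(),
    conv.weight_v.flatten(1).norm(dim=1),  # norm over all dims except 0
)

# dim=None: a single magnitude for the entire weight tensor.
lin = weight_norm(nn.Linear(20, 40), dim=None)
global_count = lin.weight_g.numel()
global_ok = torch.allclose(lin.weight_g, lin.weight_v.norm())

print(per_channel_shape, global_count)
```

Right after wrapping, the magnitude equals the norm of the original weight, so the layer initially computes the same output as before the reparameterization.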
See https://arxiv.org/abs/1602.07868
Parameters
- module (Module) – containing module
- name (str, optional) – name of weight parameter
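- dim (int, optional) – dimension that is kept when computing the norm; the norm is taken over all other dimensions. Use None to compute a single norm over the entire tensor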
Returns
The original module with the weight norm hook
Return type
T_module
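To make the return value concrete, this short sketch checks that the call modifies the module in place and returns that same object, and lists the parameter names afterwards. The names layer, wrapped, same_object, and param_names are illustrative.

```python
import torch.nn as nn
from torch.nn.utils import weight_norm

layer = nn.Linear(20, 40)
wrapped = weight_norm(layer, name="weight")

# The function returns the module it was given, not a copy.
same_object = wrapped is layer

# 'weight' is no longer a registered parameter; it is replaced by
# 'weight_g' and 'weight_v'. The bias is left untouched.
param_names = sorted(n for n, _ in wrapped.named_parameters())

print(same_object, param_names)
```

Since the original module is returned, the call can be used inline when building a model, as in the example below.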
Example:
>>> import torch.nn as nn
>>> from torch.nn.utils import weight_norm
>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_g.size()
torch.Size([40, 1])
>>> m.weight_v.size()
torch.Size([40, 20])