torch.nn.utils.spectral_norm — PyTorch 2.7 documentation

torch.nn.utils.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[source]

Apply spectral normalization to a parameter in the given module.

$$\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, \quad \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}$$

Spectral normalization stabilizes the training of discriminators (critics) in Generative Adversarial Networks (GANs) by rescaling the weight tensor with the spectral norm \sigma of the weight matrix, which is estimated using the power iteration method. If the weight tensor has more than 2 dimensions, it is reshaped to 2D for the power iteration in order to compute the spectral norm. This is implemented via a hook that calculates the spectral norm and rescales the weight before every forward() call.
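For intuition, the following is a minimal sketch of one power-iteration step on a 2D weight and the resulting rescaling. This is not PyTorch's internal implementation; the tensor shapes, seed, and single iteration are illustrative assumptions chosen to mirror the defaults above:

    import torch

    torch.manual_seed(0)
    W = torch.randn(40, 20)              # weight viewed as a 2D matrix (out_features, in_features)
    u = torch.randn(40)
    u = u / u.norm()                     # persistent estimate of the dominant left singular vector

    for _ in range(1):                   # corresponds to n_power_iterations=1
        v = W.t() @ u
        v = v / (v.norm() + 1e-12)       # eps keeps the division numerically stable
        u = W @ v
        u = u / (u.norm() + 1e-12)

    sigma = torch.dot(u, W @ v)          # estimate of the largest singular value sigma(W)
    W_sn = W / sigma                     # rescaled weight W_SN used in the forward pass
    print(sigma.item(), torch.linalg.svdvals(W)[0].item())  # estimate vs. exact spectral norm

Because the vector u persists between calls, the estimate improves as more forward passes run more power iterations, which is why a single iteration per call is usually sufficient.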

See Spectral Normalization for Generative Adversarial Networks.

Parameters

* module (nn.Module) – containing module
* name (str, optional) – name of weight parameter
* n_power_iterations (int, optional) – number of power iterations to calculate spectral norm
* eps (float, optional) – epsilon for numerical stability in calculating norms
* dim (int, optional) – dimension corresponding to number of outputs, the default is 0, except for modules that are instances of ConvTranspose{1,2,3}d, when it is 1

Returns

The original module with the spectral norm hook

Return type

T_module

Example:

>>> m = spectral_norm(nn.Linear(20, 40))
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_u.size()
torch.Size([40])
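As a rough sanity check, the largest singular value of the normalized weight should be close to 1 after a forward pass. The snippet below is a hypothetical extension of the example, not part of the documented example:

    import torch
    import torch.nn as nn
    from torch.nn.utils import spectral_norm

    m = spectral_norm(nn.Linear(20, 40))
    x = torch.randn(8, 20)
    _ = m(x)                                           # the pre-forward hook recomputes m.weight
    print(torch.linalg.svdvals(m.weight.detach())[0])  # approximately 1.0 (power-iteration estimate)

The value is approximate after a single call because the power-iteration estimate of \sigma(W) sharpens as u is refined over subsequent forward passes in training mode.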