# mypy: allow-untyped-defs

"""This file contains utilities for initializing neural network parameters.""" import math import warnings from typing import Optional as _Optional

import torch
from torch import Tensor

# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context
# managers, so these need to be implemented as builtins. Using these wrappers
# lets us keep those builtins small and re-usable.

def _no_grad_uniform_(tensor, a, b, generator=None):
    with torch.no_grad():
        return tensor.uniform_(a, b, generator=generator)

def _no_grad_normal_(tensor, mean, std, generator=None):
    with torch.no_grad():
        return tensor.normal_(mean, std, generator=generator)

def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

if (mean < a - 2 * std) or (mean > b + 2 * std):
    warnings.warn(
        "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
        "The distribution of values may be incorrect.",
        stacklevel=2,
    )

with torch.no_grad():
    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)
    return tensor
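
# Illustrative sketch (helper name assumed, not part of the API): after the
# inverse-CDF transform and the final clamp above, every sample must land
# inside [a, b].
def _example_trunc_normal_bounds():
    t = torch.empty(1000)
    _no_grad_trunc_normal_(t, mean=0.0, std=1.0, a=-2.0, b=2.0)
    assert t.min() >= -2.0 and t.max() <= 2.0
    return t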

def _no_grad_fill_(tensor, val):
    with torch.no_grad():
        return tensor.fill_(val)

def _no_grad_zero_(tensor):
    with torch.no_grad():
        return tensor.zero_()

def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

The values are as follows:

================= ====================================================
nonlinearity      gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D      :math:`1`
Sigmoid           :math:`1`
Tanh              :math:`\frac{5}{3}`
ReLU              :math:`\sqrt{2}`
Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
SELU              :math:`\frac{3}{4}`
================= ====================================================

.. warning::
    In order to implement `Self-Normalizing Neural Networks`_ ,
    you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
    This gives the initial weights a variance of ``1 / N``,
    which is necessary to induce a stable fixed point in the forward pass.
    In contrast, the default gain for ``SELU`` sacrifices the normalization
    effect for more stable gradient flow in rectangular layers.

Args:
    nonlinearity: the non-linear function (`nn.functional` name)
    param: optional parameter for the non-linear function

Examples:
    >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

.. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
"""
linear_fns = [
    "linear",
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose1d",
    "conv_transpose2d",
    "conv_transpose3d",
]
if nonlinearity in linear_fns or nonlinearity == "sigmoid":
    return 1
elif nonlinearity == "tanh":
    return 5.0 / 3
elif nonlinearity == "relu":
    return math.sqrt(2.0)
elif nonlinearity == "leaky_relu":
    if param is None:
        negative_slope = 0.01
    elif (
        not isinstance(param, bool)
        and isinstance(param, int)
        or isinstance(param, float)
    ):
        # True/False are instances of int, hence check above
        negative_slope = param
    else:
        raise ValueError(f"negative_slope {param} not a valid number")
    return math.sqrt(2.0 / (1 + negative_slope**2))
elif nonlinearity == "selu":
    return (
        3.0 / 4
    )  # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
else:
    raise ValueError(f"Unsupported nonlinearity {nonlinearity}")

def uniform_(
    tensor: Tensor,
    a: float = 0.0,
    b: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the uniform distribution.

:math:`\mathcal{U}(a, b)`.

Args:
    tensor: an n-dimensional `torch.Tensor`
    a: the lower bound of the uniform distribution
    b: the upper bound of the uniform distribution
    generator: the torch Generator to sample from (default: None)

Examples:
    >>> w = torch.empty(3, 5)
    >>> nn.init.uniform_(w)
"""
if torch.overrides.has_torch_function_variadic(tensor):
    return torch.overrides.handle_torch_function(
        uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
    )
return _no_grad_uniform_(tensor, a, b, generator)

def normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the normal distribution.

:math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

Args:
    tensor: an n-dimensional `torch.Tensor`
    mean: the mean of the normal distribution
    std: the standard deviation of the normal distribution
    generator: the torch Generator to sample from (default: None)

Examples:
    >>> w = torch.empty(3, 5)
    >>> nn.init.normal_(w)
"""
if torch.overrides.has_torch_function_variadic(tensor):
    return torch.overrides.handle_torch_function(
        normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
    )
return _no_grad_normal_(tensor, mean, std, generator)

def trunc_normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from a truncated normal distribution.

The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.

Args:
    tensor: an n-dimensional `torch.Tensor`
    mean: the mean of the normal distribution
    std: the standard deviation of the normal distribution
    a: the minimum cutoff value
    b: the maximum cutoff value
    generator: the torch Generator to sample from (default: None)

Examples:
    >>> w = torch.empty(3, 5)
    >>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)

def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Fill the input Tensor with the value :math:`\text{val}`.

Args:
    tensor: an n-dimensional `torch.Tensor`
    val: the value to fill the tensor with

Examples:
    >>> w = torch.empty(3, 5)
    >>> nn.init.constant_(w, 0.3)
"""
if torch.overrides.has_torch_function_variadic(tensor):
    return torch.overrides.handle_torch_function(
        constant_, (tensor,), tensor=tensor, val=val
    )
return _no_grad_fill_(tensor, val)

def ones_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value 1.

Args:
    tensor: an n-dimensional `torch.Tensor`

Examples:
    >>> w = torch.empty(3, 5)
    >>> nn.init.ones_(w)
"""
return _no_grad_fill_(tensor, 1.0)

def zeros_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value 0.

Args:
    tensor: an n-dimensional `torch.Tensor`

Examples:
    >>> w = torch.empty(3, 5)
    >>> nn.init.zeros_(w)
"""
return _no_grad_zero_(tensor)

def eye_(tensor):
    r"""Fill the 2-dimensional input Tensor with the identity matrix.

Preserves the identity of the inputs in `Linear` layers, where as
many inputs are preserved as possible.

Args:
    tensor: a 2-dimensional `torch.Tensor`

Examples:
    >>> w = torch.empty(3, 5)
    >>> nn.init.eye_(w)
"""
if tensor.ndimension() != 2:
    raise ValueError("Only tensors with 2 dimensions are supported")

with torch.no_grad():
    torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
return tensor

def dirac_(tensor, groups=1):
    r"""Fill the {3, 4, 5}-dimensional input Tensor with the Dirac delta function.

Preserves the identity of the inputs in `Convolutional`
layers, where as many input channels are preserved as possible. In case
of groups>1, each group of channels preserves identity

Args:
    tensor: a {3, 4, 5}-dimensional `torch.Tensor`
    groups (int, optional): number of groups in the conv layer (default: 1)
Examples:
    >>> w = torch.empty(3, 16, 5, 5)
    >>> nn.init.dirac_(w)
    >>> w = torch.empty(3, 24, 5, 5)
    >>> nn.init.dirac_(w, 3)
"""
dimensions = tensor.ndimension()
if dimensions not in [3, 4, 5]:
    raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

sizes = tensor.size()

if sizes[0] % groups != 0:
    raise ValueError("dim 0 must be divisible by groups")

out_chans_per_grp = sizes[0] // groups
min_dim = min(out_chans_per_grp, sizes[1])

with torch.no_grad():
    tensor.zero_()

    for g in range(groups):
        for d in range(min_dim):
            if dimensions == 3:  # Temporal convolution
                tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1
            elif dimensions == 4:  # Spatial convolution
                tensor[
                    g * out_chans_per_grp + d,
                    d,
                    tensor.size(2) // 2,
                    tensor.size(3) // 2,
                ] = 1
            else:  # Volumetric convolution
                tensor[
                    g * out_chans_per_grp + d,
                    d,
                    tensor.size(2) // 2,
                    tensor.size(3) // 2,
                    tensor.size(4) // 2,
                ] = 1
return tensor
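
# A short sketch of the identity-preserving property documented above
# (illustrative helper, name assumed): with a Dirac-initialized 3x3 conv
# weight and matching channel counts, same-padded convolution returns its
# input unchanged.
def _example_dirac_identity():
    import torch.nn.functional as F

    w = torch.empty(8, 8, 3, 3)
    dirac_(w)
    x = torch.randn(1, 8, 16, 16)
    y = F.conv2d(x, w, padding=1)
    assert torch.allclose(x, y)
    return y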

def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError(
            "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
        )

num_input_fmaps = tensor.size(1)
num_output_fmaps = tensor.size(0)
receptive_field_size = 1
if tensor.dim() > 2:
    # math.prod is not always available, accumulate the product manually
    # we could use functools.reduce but that is not supported by TorchScript
    for s in tensor.shape[2:]:
        receptive_field_size *= s
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size

return fan_in, fan_out
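
# A worked example of the fan computation above (illustrative helper, name
# assumed): for a conv weight of shape [out_channels=64, in_channels=3, 7, 7],
# the receptive field size is 7 * 7 = 49, so fan_in = 3 * 49 = 147 and
# fan_out = 64 * 49 = 3136.
def _example_fan_in_fan_out():
    w = torch.empty(64, 3, 7, 7)
    fan_in, fan_out = _calculate_fan_in_and_fan_out(w)
    assert (fan_in, fan_out) == (147, 3136)
    return fan_in, fan_out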

def xavier_uniform_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values using a Xavier uniform distribution.

The method is described in `Understanding the difficulty of training
deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010).
The resulting tensor will have values sampled from
:math:`\mathcal{U}(-a, a)` where

.. math::
    a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

Also known as Glorot initialization.

Args:
    tensor: an n-dimensional `torch.Tensor`
    gain: an optional scaling factor
    generator: the torch Generator to sample from (default: None)

Examples:
    >>> w = torch.empty(3, 5)
    >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))

Note:
    Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
    that the weight matrix is used in a transposed manner,
    (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
    This is important for correct initialization.
    If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
    pass in a transposed weight matrix, i.e. ``nn.init.xavier_uniform_(w.T, ...)``.
"""
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

return _no_grad_uniform_(tensor, -a, a, generator)
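
# A worked example of the bound above (illustrative helper, name assumed):
# for a 3x5 weight with gain=1, fan_in=5 and fan_out=3, so
# a = sqrt(6 / 8) ≈ 0.866 and values are drawn from U(-0.866, 0.866).
def _example_xavier_uniform_bound():
    w = xavier_uniform_(torch.empty(3, 5))
    assert w.abs().max() <= math.sqrt(6.0 / 8.0) + 1e-6
    return w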

def xavier_normal_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values using a Xavier normal distribution.

The method is described in `Understanding the difficulty of training deep feedforward
neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor
will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

.. math::
    \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

Also known as Glorot initialization.

Args:
    tensor: an n-dimensional `torch.Tensor`
    gain: an optional scaling factor
    generator: the torch Generator to sample from (default: None)

Examples:
    >>> w = torch.empty(3, 5)
    >>> nn.init.xavier_normal_(w)

Note:
    Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
    that the weight matrix is used in a transposed manner,
    (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
    This is important for correct initialization.
    If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
    pass in a transposed weight matrix, i.e. ``nn.init.xavier_normal_(w.T, ...)``.
"""
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))

return _no_grad_normal_(tensor, 0.0, std, generator)
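
# The corresponding worked example for the normal variant (illustrative helper,
# name assumed): for a 3x5 weight with gain=1, std = sqrt(2 / (5 + 3)) = 0.5,
# so roughly 99.7% of the sampled entries fall in (-1.5, 1.5).
def _example_xavier_normal_std():
    g = torch.Generator().manual_seed(0)
    return xavier_normal_(torch.empty(3, 5), generator=g)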

def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ["fan_in", "fan_out"]
    if mode not in valid_modes:
        raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")

fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == "fan_in" else fan_out

def kaiming_uniform_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input Tensor with values using a Kaiming uniform distribution.

The method is described in `Delving deep into rectifiers: Surpassing
human-level performance on ImageNet classification` - He, K. et al. (2015).
The resulting tensor will have values sampled from
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where

.. math::
    \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

Also known as He initialization.

Args:
    tensor: an n-dimensional `torch.Tensor`
    a: the negative slope of the rectifier used after this layer (only
        used with ``'leaky_relu'``)
    mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
        preserves the magnitude of the variance of the weights in the
        forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
        backwards pass.
    nonlinearity: the non-linear function (`nn.functional` name),
        recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
    generator: the torch Generator to sample from (default: None)

Examples:
    >>> w = torch.empty(3, 5)
    >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

Note:
    Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
    that the weight matrix is used in a transposed manner,
    (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
    This is important for correct initialization.
    If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
    pass in a transposed weight matrix, i.e. ``nn.init.kaiming_uniform_(w.T, ...)``.
"""
if torch.overrides.has_torch_function_variadic(tensor):
    return torch.overrides.handle_torch_function(
        kaiming_uniform_,
        (tensor,),
        tensor=tensor,
        a=a,
        mode=mode,
        nonlinearity=nonlinearity,
        generator=generator,
    )

if 0 in tensor.shape:
    warnings.warn("Initializing zero-element tensors is a no-op")
    return tensor
fan = _calculate_correct_fan(tensor, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
with torch.no_grad():
    return tensor.uniform_(-bound, bound, generator=generator)
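
# A worked example of the bound above (illustrative helper, name assumed):
# for a 64x128 weight in fan_in mode with nonlinearity='relu',
# gain = sqrt(2), std = sqrt(2 / 128) = 0.125, and bound = sqrt(3) * 0.125 ≈ 0.2165.
def _example_kaiming_uniform_bound():
    w = kaiming_uniform_(torch.empty(64, 128), mode="fan_in", nonlinearity="relu")
    assert w.abs().max() <= math.sqrt(3.0) * 0.125 + 1e-6
    return w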

def kaiming_normal_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input Tensor with values using a Kaiming normal distribution.

The method is described in `Delving deep into rectifiers: Surpassing
human-level performance on ImageNet classification` - He, K. et al. (2015).
The resulting tensor will have values sampled from
:math:`\mathcal{N}(0, \text{std}^2)` where

.. math::
    \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

Also known as He initialization.

Args:
    tensor: an n-dimensional `torch.Tensor`
    a: the negative slope of the rectifier used after this layer (only
        used with ``'leaky_relu'``)
    mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
        preserves the magnitude of the variance of the weights in the
        forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
        backwards pass.
    nonlinearity: the non-linear function (`nn.functional` name),
        recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
    generator: the torch Generator to sample from (default: None)

Examples:
    >>> w = torch.empty(3, 5)
    >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

Note:
    Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
    that the weight matrix is used in a transposed manner,
    (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
    This is important for correct initialization.
    If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
    pass in a transposed weight matrix, i.e. ``nn.init.kaiming_normal_(w.T, ...)``.
"""
if 0 in tensor.shape:
    warnings.warn("Initializing zero-element tensors is a no-op")
    return tensor
fan = _calculate_correct_fan(tensor, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
with torch.no_grad():
    return tensor.normal_(0, std, generator=generator)
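
# A worked example of fan_out mode (illustrative helper, name assumed):
# for a conv weight of shape [64, 3, 7, 7], fan_out = 64 * 7 * 7 = 3136,
# so with nonlinearity='relu' the std is sqrt(2 / 3136) ≈ 0.0253.
def _example_kaiming_normal_fan_out():
    w = torch.empty(64, 3, 7, 7)
    return kaiming_normal_(w, mode="fan_out", nonlinearity="relu")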

def orthogonal_(
    tensor,
    gain=1,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input Tensor with a (semi) orthogonal matrix.

Described in `Exact solutions to the nonlinear dynamics of learning in deep
linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
at least 2 dimensions, and for tensors with more than 2 dimensions the
trailing dimensions are flattened.

Args:
    tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
    gain: optional scaling factor
    generator: the torch Generator to sample from (default: None)

Examples:
    >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
    >>> w = torch.empty(3, 5)
    >>> nn.init.orthogonal_(w)
"""
if tensor.ndimension() < 2:
    raise ValueError("Only tensors with 2 or more dimensions are supported")

if tensor.numel() == 0:
    # no-op
    return tensor
rows = tensor.size(0)
cols = tensor.numel() // rows
flattened = tensor.new_empty((rows, cols)).normal_(0, 1, generator=generator)

if rows < cols:
    flattened.t_()

# Compute the qr factorization
q, r = torch.linalg.qr(flattened)
# Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
d = torch.diag(r, 0)
ph = d.sign()
q *= ph

if rows < cols:
    q.t_()

with torch.no_grad():
    tensor.view_as(q).copy_(q)
    tensor.mul_(gain)
return tensor
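
# A quick check of the (semi) orthogonality produced above (illustrative helper,
# name assumed): for a tall 5x3 weight the columns come out orthonormal, so
# w.t() @ w is close to the 3x3 identity.
def _example_orthogonal_check():
    w = orthogonal_(torch.empty(5, 3))
    assert torch.allclose(w.t() @ w, torch.eye(3), atol=1e-5)
    return w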

def sparse_(
    tensor,
    sparsity,
    std=0.01,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the 2D input Tensor as a sparse matrix.

The non-zero elements will be drawn from the normal distribution
:math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
Hessian-free optimization` - Martens, J. (2010).

Args:
    tensor: an n-dimensional `torch.Tensor`
    sparsity: The fraction of elements in each column to be set to zero
    std: the standard deviation of the normal distribution used to generate
        the non-zero values
    generator: the torch Generator to sample from (default: None)

Examples:
    >>> w = torch.empty(3, 5)
    >>> nn.init.sparse_(w, sparsity=0.1)
"""
if tensor.ndimension() != 2:
    raise ValueError("Only tensors with 2 dimensions are supported")

rows, cols = tensor.shape
num_zeros = int(math.ceil(sparsity * rows))

with torch.no_grad():
    tensor.normal_(0, std, generator=generator)
    for col_idx in range(cols):
        row_indices = torch.randperm(rows)
        zero_indices = row_indices[:num_zeros]
        tensor[zero_indices, col_idx] = 0
return tensor
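
# A worked example of the zeroing arithmetic above (illustrative helper, name
# assumed): with sparsity=0.1 and a 20x5 weight, num_zeros = ceil(0.1 * 20) = 2,
# i.e. two entries of each column are forced to zero.
def _example_sparse_zeros_per_column():
    w = sparse_(torch.empty(20, 5), sparsity=0.1)
    assert all(int((w[:, j] == 0).sum()) >= 2 for j in range(5))
    return w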

# for backward compatibility

def _make_deprecate(meth):
    new_name = meth.__name__
    old_name = new_name[:-1]

def deprecated_init(*args, **kwargs):
    warnings.warn(
        f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.",
        FutureWarning,
        stacklevel=2,
    )
    return meth(*args, **kwargs)

deprecated_init.__doc__ = rf"""
{old_name}(...)

.. warning::
    This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.

See :func:`~torch.nn.init.{new_name}` for details."""
deprecated_init.__name__ = old_name
return deprecated_init

uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)