torch.nn — PyTorch 2.7 documentation

These are the basic building blocks for graphs:

Pooling layers

nn.MaxPool1d Applies a 1D max pooling over an input signal composed of several input planes.
nn.MaxPool2d Applies a 2D max pooling over an input signal composed of several input planes.
nn.MaxPool3d Applies a 3D max pooling over an input signal composed of several input planes.
nn.MaxUnpool1d Computes a partial inverse of MaxPool1d.
nn.MaxUnpool2d Computes a partial inverse of MaxPool2d.
nn.MaxUnpool3d Computes a partial inverse of MaxPool3d.
nn.AvgPool1d Applies a 1D average pooling over an input signal composed of several input planes.
nn.AvgPool2d Applies a 2D average pooling over an input signal composed of several input planes.
nn.AvgPool3d Applies a 3D average pooling over an input signal composed of several input planes.
nn.FractionalMaxPool2d Applies a 2D fractional max pooling over an input signal composed of several input planes.
nn.FractionalMaxPool3d Applies a 3D fractional max pooling over an input signal composed of several input planes.
nn.LPPool1d Applies a 1D power-average pooling over an input signal composed of several input planes.
nn.LPPool2d Applies a 2D power-average pooling over an input signal composed of several input planes.
nn.LPPool3d Applies a 3D power-average pooling over an input signal composed of several input planes.
nn.AdaptiveMaxPool1d Applies a 1D adaptive max pooling over an input signal composed of several input planes.
nn.AdaptiveMaxPool2d Applies a 2D adaptive max pooling over an input signal composed of several input planes.
nn.AdaptiveMaxPool3d Applies a 3D adaptive max pooling over an input signal composed of several input planes.
nn.AdaptiveAvgPool1d Applies a 1D adaptive average pooling over an input signal composed of several input planes.
nn.AdaptiveAvgPool2d Applies a 2D adaptive average pooling over an input signal composed of several input planes.
nn.AdaptiveAvgPool3d Applies a 3D adaptive average pooling over an input signal composed of several input planes.
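
A minimal sketch of the fixed-size vs. adaptive pooling modules listed above (tensor shapes are illustrative only):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 16, 32, 32)             # (batch, channels, height, width)

# Fixed-size pooling: a 2x2 window with stride 2 halves the spatial dimensions.
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
print(max_pool(x).shape)                   # torch.Size([1, 16, 16, 16])

# Adaptive pooling targets an output size instead of a kernel size.
adaptive_avg = nn.AdaptiveAvgPool2d(output_size=(1, 1))
print(adaptive_avg(x).shape)               # torch.Size([1, 16, 1, 1])
```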

Non-linear Activations (weighted sum, nonlinearity)

nn.ELU Applies the Exponential Linear Unit (ELU) function, element-wise.
nn.Hardshrink Applies the Hard Shrinkage (Hardshrink) function element-wise.
nn.Hardsigmoid Applies the Hardsigmoid function element-wise.
nn.Hardtanh Applies the HardTanh function element-wise.
nn.Hardswish Applies the Hardswish function, element-wise.
nn.LeakyReLU Applies the LeakyReLU function element-wise.
nn.LogSigmoid Applies the Logsigmoid function element-wise.
nn.MultiheadAttention Allows the model to jointly attend to information from different representation subspaces.
nn.PReLU Applies the element-wise PReLU function.
nn.ReLU Applies the rectified linear unit function element-wise.
nn.ReLU6 Applies the ReLU6 function element-wise.
nn.RReLU Applies the randomized leaky rectified linear unit function, element-wise.
nn.SELU Applies the SELU function element-wise.
nn.CELU Applies the CELU function element-wise.
nn.GELU Applies the Gaussian Error Linear Units function.
nn.Sigmoid Applies the Sigmoid function element-wise.
nn.SiLU Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
nn.Mish Applies the Mish function, element-wise.
nn.Softplus Applies the Softplus function element-wise.
nn.Softshrink Applies the soft shrinkage function element-wise.
nn.Softsign Applies the element-wise Softsign function.
nn.Tanh Applies the Hyperbolic Tangent (Tanh) function element-wise.
nn.Tanhshrink Applies the element-wise Tanhshrink function.
nn.Threshold Thresholds each element of the input Tensor.
nn.GLU Applies the gated linear unit function.
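
These activations are all element-wise modules with the same call pattern; a brief sketch (input values are arbitrary):

```python
import torch
import torch.nn as nn

x = torch.randn(4)

relu = nn.ReLU()
gelu = nn.GELU()
leaky = nn.LeakyReLU(negative_slope=0.01)

print(relu(x))     # negative entries clamped to 0
print(gelu(x))     # smooth Gaussian-error gating of each element
print(leaky(x))    # negative entries scaled by 0.01 instead of zeroed
```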

Non-linear Activations (other)

nn.Softmin Applies the Softmin function to an n-dimensional input Tensor.
nn.Softmax Applies the Softmax function to an n-dimensional input Tensor.
nn.Softmax2d Applies SoftMax over features to each spatial location.
nn.LogSoftmax Applies the log(Softmax(x)) function to an n-dimensional input Tensor.
nn.AdaptiveLogSoftmaxWithLoss Efficient softmax approximation.
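
A short sketch contrasting Softmax and LogSoftmax on a batch of logits (shapes are illustrative):

```python
import torch
import torch.nn as nn

logits = torch.randn(2, 5)                 # (batch, classes)

softmax = nn.Softmax(dim=1)
log_softmax = nn.LogSoftmax(dim=1)

probs = softmax(logits)
print(probs.sum(dim=1))                                  # each row sums to 1
print(torch.allclose(log_softmax(logits), probs.log()))  # True, up to numerics
```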

Recurrent Layers

nn.RNNBase Base class for RNN modules (RNN, LSTM, GRU).
nn.RNN Apply a multi-layer Elman RNN with tanh or ReLU non-linearity to an input sequence.
nn.LSTM Apply a multi-layer long short-term memory (LSTM) RNN to an input sequence.
nn.GRU Apply a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
nn.RNNCell An Elman RNN cell with tanh or ReLU non-linearity.
nn.LSTMCell A long short-term memory (LSTM) cell.
nn.GRUCell A gated recurrent unit (GRU) cell.
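
A minimal sketch of the recurrent-layer interface, using nn.LSTM (sizes are illustrative):

```python
import torch
import torch.nn as nn

# batch_first=True means inputs have shape (batch, seq_len, input_size).
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)

x = torch.randn(3, 7, 10)                  # batch of 3 sequences of length 7
output, (h_n, c_n) = lstm(x)

print(output.shape)                        # torch.Size([3, 7, 20]) -- hidden state at every step
print(h_n.shape)                           # torch.Size([2, 3, 20]) -- final hidden state per layer
```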

Linear Layers

nn.Identity A placeholder identity operator that is argument-insensitive.
nn.Linear Applies an affine linear transformation to the incoming data: y = xA^T + b.
nn.Bilinear Applies a bilinear transformation to the incoming data: y = x_1^T A x_2 + b.
nn.LazyLinear A torch.nn.Linear module where in_features is inferred.
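
A brief sketch of nn.Linear next to its lazy variant (feature sizes are illustrative):

```python
import torch
import torch.nn as nn

x = torch.randn(8, 128)

linear = nn.Linear(in_features=128, out_features=64)
print(linear(x).shape)                     # torch.Size([8, 64])

# LazyLinear infers in_features from the first forward pass.
lazy = nn.LazyLinear(out_features=64)
print(lazy(x).shape)                       # torch.Size([8, 64])
print(lazy.in_features)                    # 128, materialized after the first call
```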

Sparse Layers

nn.Embedding A simple lookup table that stores embeddings of a fixed dictionary and size.
nn.EmbeddingBag Compute sums or means of 'bags' of embeddings, without instantiating the intermediate embeddings.
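
A minimal sketch of both sparse layers (vocabulary and embedding sizes are illustrative):

```python
import torch
import torch.nn as nn

# 1000-token vocabulary, 64-dimensional embeddings.
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=64)

token_ids = torch.tensor([[1, 5, 42], [7, 0, 999]])    # (batch, seq_len)
print(embedding(token_ids).shape)                      # torch.Size([2, 3, 64])

# EmbeddingBag reduces each row ("bag") directly, without the intermediate lookup tensor.
bag = nn.EmbeddingBag(num_embeddings=1000, embedding_dim=64, mode="mean")
print(bag(token_ids).shape)                            # torch.Size([2, 64])
```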

Distance Functions

nn.CosineSimilarity Returns cosine similarity between x_1 and x_2, computed along dim.
nn.PairwiseDistance Computes the pairwise distance between input vectors, or between columns of input matrices.
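
Both distance modules operate row-wise on batched inputs; a quick sketch (shapes are illustrative):

```python
import torch
import torch.nn as nn

x1 = torch.randn(5, 128)
x2 = torch.randn(5, 128)

cos = nn.CosineSimilarity(dim=1)
dist = nn.PairwiseDistance(p=2)

print(cos(x1, x2).shape)    # torch.Size([5]) -- one similarity per row, in [-1, 1]
print(dist(x1, x2).shape)   # torch.Size([5]) -- one Euclidean distance per row
```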

Loss Functions

nn.L1Loss Creates a criterion that measures the mean absolute error (MAE) between each element in the input x and target y.
nn.MSELoss Creates a criterion that measures the mean squared error (squared L2 norm) between each element in the input x and target y.
nn.CrossEntropyLoss This criterion computes the cross entropy loss between input logits and target.
nn.CTCLoss The Connectionist Temporal Classification loss.
nn.NLLLoss The negative log likelihood loss.
nn.PoissonNLLLoss Negative log likelihood loss with Poisson distribution of target.
nn.GaussianNLLLoss Gaussian negative log likelihood loss.
nn.KLDivLoss The Kullback-Leibler divergence loss.
nn.BCELoss Creates a criterion that measures the Binary Cross Entropy between the target and the input probabilities.
nn.BCEWithLogitsLoss This loss combines a Sigmoid layer and the BCELoss in one single class.
nn.MarginRankingLoss Creates a criterion that measures the loss given inputs x1, x2, two 1D mini-batch or 0D Tensors, and a label 1D mini-batch or 0D Tensor y (containing 1 or -1).
nn.HingeEmbeddingLoss Measures the loss given an input tensor x and a labels tensor y (containing 1 or -1).
nn.MultiLabelMarginLoss Creates a criterion that optimizes a multi-class multi-classification hinge loss (margin-based loss) between input x (a 2D mini-batch Tensor) and output y (which is a 2D Tensor of target class indices).
nn.HuberLoss Creates a criterion that uses a squared term if the absolute element-wise error falls below delta and a delta-scaled L1 term otherwise.
nn.SmoothL1Loss Creates a criterion that uses a squared term if the absolute element-wise error falls below beta and an L1 term otherwise.
nn.SoftMarginLoss Creates a criterion that optimizes a two-class classification logistic loss between input tensor x and target tensor y (containing 1 or -1).
nn.MultiLabelSoftMarginLoss Creates a criterion that optimizes a multi-label one-versus-all loss based on max-entropy, between input x and target y of size (N, C).
nn.CosineEmbeddingLoss Creates a criterion that measures the loss given input tensors x_1, x_2 and a Tensor label y with values 1 or -1.
nn.MultiMarginLoss Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between input x (a 2D mini-batch Tensor) and output y (a 1D tensor of target class indices, 0 ≤ y ≤ x.size(1) - 1).
nn.TripletMarginLoss Creates a criterion that measures the triplet loss given input tensors x1, x2, x3 and a margin with a value greater than 0.
nn.TripletMarginWithDistanceLoss Creates a criterion that measures the triplet loss given input tensors a, p, and n (representing anchor, positive, and negative examples, respectively), and a nonnegative, real-valued function ("distance function") used to compute the relationship between the anchor and positive example ("positive distance") and the anchor and negative example ("negative distance").
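
All loss modules follow the same criterion pattern: construct once, then call with predictions and targets. A minimal sketch using nn.CrossEntropyLoss (batch and class counts are illustrative):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()          # expects raw logits; applies LogSoftmax + NLL internally

logits = torch.randn(4, 10, requires_grad=True)   # (batch, num_classes)
targets = torch.tensor([1, 0, 9, 3])              # class indices

loss = criterion(logits, targets)
loss.backward()                                   # gradients flow back to the logits
print(loss.item())
```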

Vision Layers

nn.PixelShuffle Rearrange elements in a tensor according to an upscaling factor.
nn.PixelUnshuffle Reverse the PixelShuffle operation.
nn.Upsample Upsamples multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
nn.UpsamplingNearest2d Applies a 2D nearest neighbor upsampling to an input signal composed of several input channels.
nn.UpsamplingBilinear2d Applies a 2D bilinear upsampling to an input signal composed of several input channels.
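
A short sketch contrasting PixelShuffle (which trades channels for resolution) with interpolation-based Upsample (shapes are illustrative):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 16, 8, 8)

# PixelShuffle: (C * r^2, H, W) -> (C, H * r, W * r).
shuffle = nn.PixelShuffle(upscale_factor=2)
print(shuffle(x).shape)                    # torch.Size([1, 4, 16, 16])

# Upsample interpolates to a larger spatial size without changing channels.
up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
print(up(x).shape)                         # torch.Size([1, 16, 16, 16])
```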

Utilities

From the torch.nn.utils module:

Utility functions to clip parameter gradients.

clip_grad_norm_ Clip the gradient norm of an iterable of parameters.
clip_grad_norm Clip the gradient norm of an iterable of parameters (deprecated in favor of clip_grad_norm_).
clip_grad_value_ Clip the gradients of an iterable of parameters at specified value.
get_total_norm Compute the norm of an iterable of tensors.
clip_grads_with_norm_ Scale the gradients of an iterable of parameters given a pre-calculated total norm and desired max norm.
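
A typical place for gradient clipping is between backward() and the optimizer step; a minimal sketch (model and hyperparameters are illustrative):

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(4, 10)).sum()
loss.backward()

# Rescale gradients in place so their total norm is at most 1.0.
total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
print(total_norm)                          # total norm measured before clipping
optimizer.step()
```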

Utility functions to flatten and unflatten Module parameters to and from a single vector.
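
The functions in question (parameters_to_vector and vector_to_parameters in torch.nn.utils, not listed on this page) round-trip between a module's parameters and one flat tensor; a brief sketch:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

model = nn.Linear(4, 3)

# Flatten all parameters into one 1D tensor (4*3 weights + 3 biases = 15 values here).
vec = parameters_to_vector(model.parameters())
print(vec.shape)                           # torch.Size([15])

# Write a (possibly modified) vector back into the parameters in place.
vector_to_parameters(vec * 0.5, model.parameters())
```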

Utility functions to fuse Modules with BatchNorm modules.

fuse_conv_bn_eval Fuse a convolutional module and a BatchNorm module into a single, new convolutional module.
fuse_conv_bn_weights Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters.
fuse_linear_bn_eval Fuse a linear module and a BatchNorm module into a single, new linear module.
fuse_linear_bn_weights Fuse linear module parameters and BatchNorm module parameters into new linear module parameters.
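
Fusion folds the BatchNorm statistics into the convolution's weights, which only makes sense with frozen statistics; a small sketch (layer sizes are illustrative):

```python
import torch
import torch.nn as nn
from torch.nn.utils import fuse_conv_bn_eval

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()
bn = nn.BatchNorm2d(8).eval()              # fusion is only valid in eval mode

fused = fuse_conv_bn_eval(conv, bn)        # a single Conv2d with the BN folded in

x = torch.randn(1, 3, 16, 16)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-6))   # True
```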

Utility functions to convert Module parameter memory formats.
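
The dedicated helpers are not listed here; as a general illustration of the idea, Module.to can rewrite 4D parameters to the channels_last memory format (a hedged sketch, not the torch.nn.utils API itself):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)

# Convert the module's 4D parameters to channels_last, which some backends prefer for convolutions.
conv = conv.to(memory_format=torch.channels_last)
print(conv.weight.is_contiguous(memory_format=torch.channels_last))   # True
```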

Utility functions to apply and remove weight normalization from Module parameters.

weight_norm Apply weight normalization to a parameter in the given module.
remove_weight_norm Remove the weight normalization reparameterization from a module.
spectral_norm Apply spectral normalization to a parameter in the given module.
remove_spectral_norm Remove the spectral normalization reparameterization from a module.
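
Weight normalization reparameterizes a weight into a magnitude and a direction; a minimal sketch with nn.Linear (sizes are illustrative; newer releases also offer a parametrization-based variant):

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm

layer = weight_norm(nn.Linear(20, 40), name="weight")

# The weight is now recomputed from two new parameters on every forward pass.
print(layer.weight_g.shape)                # torch.Size([40, 1])  -- per-row magnitudes
print(layer.weight_v.shape)                # torch.Size([40, 20]) -- directions

remove_weight_norm(layer)                  # fold back into a plain .weight parameter
```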

Utility functions for initializing Module parameters.

skip_init Given a module class object and args / kwargs, instantiate the module without initializing parameters / buffers.
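
skip_init is useful when you intend to overwrite the parameters anyway; a brief sketch:

```python
import torch
import torch.nn as nn

# Instantiate without running the default initialization; the parameters hold
# uninitialized memory until they are explicitly overwritten.
m = nn.utils.skip_init(nn.Linear, 10, 5)
print(m.weight.shape)                      # torch.Size([5, 10])

nn.init.xavier_uniform_(m.weight)          # now initialize explicitly
nn.init.zeros_(m.bias)
```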

Utility classes and functions for pruning Module parameters.

prune.BasePruningMethod Abstract base class for creation of new pruning techniques.
prune.PruningContainer Container holding a sequence of pruning methods for iterative pruning.
prune.Identity Utility pruning method that does not prune any units but generates the pruning parametrization with a mask of ones.
prune.RandomUnstructured Prune (currently unpruned) units in a tensor at random.
prune.L1Unstructured Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm.
prune.RandomStructured Prune entire (currently unpruned) channels in a tensor at random.
prune.LnStructured Prune entire (currently unpruned) channels in a tensor based on their Ln-norm.
prune.CustomFromMask Prune a tensor by applying a user-provided, pre-computed mask.
prune.identity Apply pruning reparametrization without pruning any units.
prune.random_unstructured Prune tensor by removing random (currently unpruned) units.
prune.l1_unstructured Prune tensor by removing units with the lowest L1-norm.
prune.random_structured Prune tensor by removing random channels along the specified dimension.
prune.ln_structured Prune tensor by removing channels with the lowest Ln-norm along the specified dimension.
prune.global_unstructured Globally prunes tensors corresponding to all parameters in parameters by applying the specified pruning_method.
prune.custom_from_mask Prune tensor corresponding to parameter called name in module by applying the pre-computed mask in mask.
prune.remove Remove the pruning reparameterization from a module and the pruning method from the forward hook.
prune.is_pruned Check if a module is pruned by looking for pruning pre-hooks.
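
A minimal pruning sketch using the function-style API above (layer size and pruning amount are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(10, 5)

# Zero out the 30% of weights with the smallest absolute value.
prune.l1_unstructured(layer, name="weight", amount=0.3)

print(hasattr(layer, "weight_orig"))       # True -- original values kept as weight_orig
print((layer.weight == 0).float().mean())  # tensor(0.3000)
print(prune.is_pruned(layer))              # True

prune.remove(layer, "weight")              # make the pruning permanent
```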

Parametrizations implemented using the new parametrization functionality in torch.nn.utils.parametrize.register_parametrization().

Utility functions to parametrize Tensors on existing Modules. Note that these functions can be used to parametrize a given Parameter or Buffer given a specific function that maps from an input space to the parametrized space. They are not parameterizations that would transform an object into a parameter. See the Parametrizations tutorial for more information on how to implement your own parametrizations.
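
A small sketch of register_parametrization, constraining a Linear weight to be symmetric (the Symmetric module is a made-up example, in the spirit of the tutorial referenced above):

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    # Maps an unconstrained square matrix to a symmetric one.
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", Symmetric())

W = layer.weight                           # recomputed through the parametrization on access
print(torch.allclose(W, W.T))              # True
```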

Utility functions to call a given Module in a stateless manner.

stateless.functional_call Perform a functional call on the module by replacing the module parameters and buffers with the provided ones.
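
A brief sketch of functional_call, running a module with substitute parameters without mutating it (the replacement values are arbitrary):

```python
import torch
import torch.nn as nn
from torch.nn.utils.stateless import functional_call

model = nn.Linear(3, 2)
x = torch.randn(1, 3)

# Evaluate the module with these parameters instead of its own.
new_params = {"weight": torch.zeros(2, 3), "bias": torch.ones(2)}
out = functional_call(model, new_params, (x,))
print(out)                                 # all ones: zero weight plus unit bias
```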

Utility functions in other modules

nn.utils.rnn.PackedSequence Holds the data and list of batch_sizes of a packed sequence.
nn.utils.rnn.pack_padded_sequence Packs a Tensor containing padded sequences of variable length.
nn.utils.rnn.pad_packed_sequence Pad a packed batch of variable length sequences.
nn.utils.rnn.pad_sequence Pad a list of variable length Tensors with padding_value.
nn.utils.rnn.pack_sequence Packs a list of variable length Tensors.
nn.utils.rnn.unpack_sequence Unpack PackedSequence into a list of variable length Tensors.
nn.utils.rnn.unpad_sequence Unpad padded Tensor into a list of variable length Tensors.
nn.Flatten Flattens a contiguous range of dims into a tensor.
nn.Unflatten Unflattens a tensor dim expanding it to a desired shape.
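
The packing utilities let recurrent layers skip padding; a minimal sketch combining pad_sequence, pack_padded_sequence, and pad_packed_sequence (sequence lengths are illustrative):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# Three variable-length sequences of 8-dimensional features.
seqs = [torch.randn(5, 8), torch.randn(3, 8), torch.randn(2, 8)]
lengths = torch.tensor([5, 3, 2])

padded = pad_sequence(seqs, batch_first=True)              # (3, 5, 8), zero-padded
packed = pack_padded_sequence(padded, lengths, batch_first=True)

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
packed_out, _ = lstm(packed)                               # padded positions are never processed
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(out.shape, out_lengths)                              # torch.Size([3, 5, 16]) tensor([5, 3, 2])
```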

Quantized Functions

Quantization refers to techniques for performing computations and storing tensors at lower bitwidths than floating point precision. PyTorch supports both per-tensor and per-channel asymmetric linear quantization. To learn more about how to use quantized functions in PyTorch, please refer to the Quantization documentation.
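
As a minimal illustration of per-tensor linear quantization (scale and zero point chosen arbitrarily):

```python
import torch

x = torch.randn(2, 3)

# Per-tensor asymmetric linear quantization: q = round(x / scale) + zero_point, then clamp.
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=10, dtype=torch.quint8)
print(q.dtype)             # torch.quint8
print(q.dequantize())      # approximately x, up to the 0.1 quantization step
```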

Lazy Modules Initialization
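
Lazy modules defer shape-dependent parameter creation until the first forward pass; a short sketch (the architecture is illustrative):

```python
import torch
import torch.nn as nn

net = nn.Sequential(
    nn.LazyConv2d(out_channels=8, kernel_size=3),    # in_channels inferred later
    nn.ReLU(),
    nn.Flatten(),
    nn.LazyLinear(out_features=10),                  # in_features inferred later
)

x = torch.randn(1, 3, 16, 16)
print(net(x).shape)        # torch.Size([1, 10]); all parameters are now materialized
```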

Aliases

The following are aliases to their counterparts in torch.nn: