Module — PyTorch 2.7 documentation

class torch.nn.Module(*args, **kwargs)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

Note

As per the example above, an __init__() call to the parent class must be made before assignment on the child.
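The ordering requirement in the note can be seen in a minimal sketch (the classes Broken and Fixed are hypothetical): assigning a submodule before the parent __init__() has run fails with an AttributeError, because the module's internal submodule registry does not exist yet.

```python
import torch.nn as nn

class Broken(nn.Module):
    def __init__(self) -> None:
        # Assigning a submodule before nn.Module.__init__() has run fails:
        # the internal registry of submodules has not been created yet.
        self.fc = nn.Linear(4, 2)
        super().__init__()

class Fixed(nn.Module):
    def __init__(self) -> None:
        super().__init__()          # must come first
        self.fc = nn.Linear(4, 2)   # now registered as a child module

broken_error = None
try:
    Broken()
except AttributeError as exc:
    broken_error = exc

fixed = Fixed()
print(type(broken_error).__name__)                  # AttributeError
print([name for name, _ in fixed.named_children()])  # ['fc']
```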

Variables

training (bool) – Boolean indicating whether this module is in training or evaluation mode.

add_module(name, module)

Add a child module to the current module.

The module can be accessed as an attribute using the given name.
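add_module() is mostly useful when the attribute name is only known at runtime, for example when layers are created in a loop. A minimal sketch, assuming a hypothetical Stack class; children are kept in registration order, so forward() can iterate over them.

```python
import torch
import torch.nn as nn

class Stack(nn.Module):
    """Hypothetical module whose layers are created in a loop."""

    def __init__(self, depth: int) -> None:
        super().__init__()
        for i in range(depth):
            # Names built at runtime cannot be written as `self.name = ...`,
            # so add_module() registers each child under the given string.
            self.add_module(f"layer{i}", nn.Linear(4, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.children():  # registration order
            x = layer(x)
        return x

stack = Stack(depth=3)
print([name for name, _ in stack.named_children()])  # ['layer0', 'layer1', 'layer2']
print(stack.layer1 is stack.get_submodule("layer1"))  # True
out = stack(torch.randn(2, 4))
print(out.shape)  # torch.Size([2, 4])
```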

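Parameters

name (str) – name of the child module. The child module can be accessed from this module using the given name.

module (Module) – child module to be added to the module.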

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

Typical use includes initializing the parameters of a model (see also torch.nn.init).

Parameters

fn (Module -> None) – function to be applied to each submodule

Returns

self

Return type

Module

Example:

>>> @torch.no_grad()
>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

Note

This method modifies the module in-place.

Returns

self

Return type

Module

buffers(recurse=True)

Return an iterator over module buffers.

Parameters

recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.

Yields

torch.Tensor – module buffer

Return type

Iterator[Tensor]

Example:

>>> for buf in model.buffers():
>>>     print(type(buf), buf.size())
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])

children()

Return an iterator over immediate children modules.

Yields

Module – a child module

Return type

Iterator[Module]

compile(*args, **kwargs)

Compile this Module’s forward using torch.compile().

This Module’s __call__ method is compiled and all arguments are passed as-is to torch.compile().

See torch.compile() for details on the arguments for this function.

cpu()

Move all model parameters and buffers to the CPU.

Note

This method modifies the module in-place.

Returns

self

Return type

Module

cuda(device=None)

Move all model parameters and buffers to the GPU.

This also makes the associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on the GPU while being optimized.

Note

This method modifies the module in-place.

Parameters

device (int, optional) – if specified, all parameters will be copied to that device

Returns

self

Return type

Module

double()

Casts all floating point parameters and buffers to double datatype.

Note

This method modifies the module in-place.

Returns

self

Return type

Module

eval()

Set the module in evaluation mode.

This has an effect only on certain modules, such as Dropout and BatchNorm. See the documentation of particular modules for details of their behavior in training/evaluation mode, i.e., whether they are affected.

This is equivalent to self.train(False).

See Locally disabling gradient computation for a comparison between .eval() and several similar mechanisms that may be confused with it.

Returns

self

Return type

Module
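A short sketch of what eval() does and does not do: it flips the training flag on the module and every submodule (here switching Dropout to the identity), but it does not disable gradient tracking, which is why torch.no_grad() is used separately. The layer sizes are arbitrary.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))

model.eval()  # same as model.train(False)
# The flag is propagated recursively to every submodule.
print(all(not m.training for m in model.modules()))  # True

x = torch.ones(1, 8)
with torch.no_grad():  # gradient tracking is a separate switch
    # In evaluation mode Dropout is the identity, so repeated calls agree.
    print(torch.equal(model(x), model(x)))  # True

model.train()  # back to training mode
print(model[1].training)  # True
```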

extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Return type

str
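A minimal sketch of overriding extra_repr(), using a hypothetical Scale module: the returned string is placed inside the parentheses of the module's printed representation.

```python
import torch
import torch.nn as nn

class Scale(nn.Module):
    """Hypothetical module that multiplies its input by a constant."""

    def __init__(self, factor: float) -> None:
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor

    def extra_repr(self) -> str:
        # Shown between the parentheses when the module is printed.
        return f"factor={self.factor}"

print(Scale(2.0))  # Scale(factor=2.0)
```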

float()

Casts all floating point parameters and buffers to float datatype.

Note

This method modifies the module in-place.

Returns

self

Return type

Module

forward(*input)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, you should call the Module instance itself instead of calling forward() directly, since the former takes care of running the registered hooks while the latter silently ignores them.
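The difference can be observed with a forward hook that records each time it fires: calling the module runs it, calling forward() directly skips it. A minimal sketch; the layer size is arbitrary.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 3)
calls = []

# A forward hook that records each time it fires.
handle = layer.register_forward_hook(lambda mod, args, out: calls.append("hook"))

x = torch.randn(1, 3)
layer(x)          # __call__: runs forward() and the registered hooks
layer.forward(x)  # direct call: the hook is skipped

print(calls)  # ['hook']
handle.remove()
```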

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.

Parameters

target (str) – The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns

The buffer referenced by target

Return type

torch.Tensor

Raises

AttributeError – If the target string references an invalid path or resolves to something that is not a buffer

get_extra_state()

Return any extra state to include in the module’s state_dict.

Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().

Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns

Any extra state to store in the module’s state_dict

Return type

object
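A sketch of the get_extra_state()/set_extra_state() pair, assuming a hypothetical Counter module that keeps a plain Python step count. For the root module the extra state appears in the state_dict under the key "_extra_state", and load_state_dict() hands it back to set_extra_state().

```python
import torch
import torch.nn as nn

class Counter(nn.Module):
    """Hypothetical module that tracks a plain-Python step counter."""

    def __init__(self) -> None:
        super().__init__()
        self.steps = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.steps += 1
        return x

    def get_extra_state(self) -> dict:
        # Must be picklable so the state_dict can be serialized.
        return {"steps": self.steps}

    def set_extra_state(self, state: dict) -> None:
        # Receives the object returned by get_extra_state().
        self.steps = state["steps"]

src = Counter()
src(torch.zeros(1))
src(torch.zeros(1))
sd = src.state_dict()
print(list(sd.keys()))  # ['_extra_state']

dst = Counter()
dst.load_state_dict(sd)
print(dst.steps)  # 2
```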

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.

Parameters

target (str) – The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns

The Parameter referenced by target

Return type

torch.nn.Parameter

Raises

AttributeError – If the target string references an invalid path or resolves to something that is not an nn.Parameter

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

For example, let’s say you have an nn.Module A that looks like this:

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)

To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").

The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.

Parameters

target (str) – The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.)

Returns

The submodule referenced by target

Return type

torch.nn.Module

Raises

AttributeError – If at any point along the path resulting from the target string the (sub)path resolves to a non-existent attribute name or an object that is not an instance of nn.Module.
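The diagram above can be rebuilt with plain nn.Module containers to try get_submodule() and get_parameter() on fully-qualified names. A minimal sketch; the variable names simply mirror the diagram.

```python
import torch.nn as nn

# Rebuild the example module A from the diagram above.
net_c = nn.Module()
net_c.conv = nn.Conv2d(16, 33, 3, stride=2)
net_b = nn.Module()
net_b.net_c = net_c
net_b.linear = nn.Linear(100, 200)
A = nn.Module()
A.net_b = net_b

conv = A.get_submodule("net_b.net_c.conv")
weight = A.get_parameter("net_b.linear.weight")
print(type(conv).__name__, tuple(weight.shape))  # Conv2d (200, 100)

# A path that does not exist raises AttributeError.
missing_error = None
try:
    A.get_submodule("net_b.missing")
except AttributeError as exc:
    missing_error = exc
```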

half()

Casts all floating point parameters and buffers to half datatype.

Note

This method modifies the module in-place.

Returns

self

Return type

Module

ipu(device=None)

Move all model parameters and buffers to the IPU.

This also makes the associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on the IPU while being optimized.

Note

This method modifies the module in-place.

Parameters

device (int, optional) – if specified, all parameters will be copied to that device

Returns

self

Return type

Module

load_state_dict(state_dict, strict=True, assign=False)

Copy parameters and buffers from state_dict into this module and its descendants.

If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

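Parameters

state_dict (dict) – a dict containing parameters and persistent buffers.

strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

assign (bool, optional) – when False, the properties of the tensors in the current module are preserved; when True, the properties of the tensors in the state dict are preserved. The only exception is the requires_grad field of Parameter, for which the value from the module is preserved. Default: False

Returns

missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict. unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.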

Return type

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
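A sketch of how strict affects loading a checkpoint whose keys do not match: with strict=False the mismatches are reported in the returned NamedTuple and the matching entries are still copied; with the default strict=True the same call raises a RuntimeError. The extra key name "scale" is a hypothetical example.

```python
import torch
import torch.nn as nn

src = nn.Linear(4, 2)
dst = nn.Linear(4, 2)

sd = src.state_dict()
del sd["bias"]               # simulate a checkpoint with a missing entry
sd["scale"] = torch.ones(1)  # ... and one with an unknown entry

# strict=False reports the mismatches instead of raising.
result = dst.load_state_dict(sd, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['scale']
print(torch.equal(dst.weight, src.weight))  # True: matching keys were copied

# With the default strict=True the same call raises.
strict_error = None
try:
    dst.load_state_dict(sd)
except RuntimeError as exc:
    strict_error = exc
```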

modules()

Return an iterator over all modules in the network.

Yields

Module – a module in the network

Return type

Iterator[Module]

Note

Duplicate modules are returned only once. In the following example, l will be returned only once.

Example:

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)

0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)

mtia(device=None)

Move all model parameters and buffers to the MTIA.

This also makes the associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on the MTIA while being optimized.

Note

This method modifies the module in-place.

Parameters

device (int, optional) – if specified, all parameters will be copied to that device

Returns

self

Return type

Module

named_buffers(prefix='', recurse=True, remove_duplicate=True)

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

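Parameters

prefix (str) – prefix to prepend to all buffer names.

recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.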

Yields

(str, torch.Tensor) – Tuple containing the name and buffer

Return type

Iterator[tuple[str, torch.Tensor]]

Example:

>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields

(str, Module) – Tuple containing a name and child module

Return type

Iterator[tuple[str, ‘Module’]]

Example:

>>> for name, module in model.named_children():
>>>     if name in ['conv4', 'conv5']:
>>>         print(module)

named_modules(memo=None, prefix='', remove_duplicate=True)

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

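Parameters

memo – a memo to store the set of modules already added to the result

prefix – a prefix that will be added to the name of the module

remove_duplicate – whether to remove the duplicated module instances in the result or not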

Yields

(str, Module) – Tuple of name and module

Note

Duplicate modules are returned only once. In the following example, l will be returned only once.

Example:

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
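The remove_duplicate and prefix arguments change what the same shared-layer network yields. A minimal sketch, with the shared layer named shared instead of l for readability: by default the shared Linear appears once under its first name, while remove_duplicate=False yields every path that reaches it.

```python
import torch.nn as nn

shared = nn.Linear(2, 2)
net = nn.Sequential(shared, shared)

# Default: the shared Linear is yielded once, under its first name.
print([name for name, _ in net.named_modules()])  # ['', '0']

# remove_duplicate=False yields every path to the shared instance.
print([name for name, _ in net.named_modules(remove_duplicate=False)])  # ['', '0', '1']

# prefix is prepended to every yielded name.
print([name for name, _ in net.named_modules(prefix="net")])  # ['net', 'net.0']
```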

named_parameters(prefix='', recurse=True, remove_duplicate=True)

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

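Parameters

prefix (str) – prefix to prepend to all parameter names.

recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.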

Yields

(str, Parameter) – Tuple containing the name and parameter

Return type

Iterator[tuple[str, torch.nn.parameter.Parameter]]

Example:

>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())

parameters(recurse=True)

Return an iterator over module parameters.

This is typically passed to an optimizer.

Parameters

recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields

Parameter – module parameter

Return type

Iterator[Parameter]

Example:

>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.nn.parameter.Parameter'> torch.Size([20])
<class 'torch.nn.parameter.Parameter'> torch.Size([20, 1, 5, 5])

register_backward_hook(hook)

Register a backward hook on the module.

This function is deprecated in favor of register_full_backward_hook() and the behavior of this function will change in future versions.

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_buffer(name, tensor, persistent=True)

Add a buffer to the module.

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm’s running_mean is not a parameter, but is part of the module’s state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module’s state_dict.

Buffers can be accessed as attributes using given names.

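Parameters

name (str) – name of the buffer. The buffer can be accessed from this module using the given name.

tensor (Tensor or None) – buffer to be registered. If None, then operations that run on buffers, such as cuda(), are ignored, and the buffer is not included in the module’s state_dict.

persistent (bool) – whether the buffer is part of this module’s state_dict.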

Example:

self.register_buffer('running_mean', torch.zeros(num_features))
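The persistent flag only controls whether a buffer is saved in the state_dict; both kinds of buffer are still module state that dtype and device conversions apply to, and neither is a parameter. A minimal sketch with a hypothetical RunningStats module:

```python
import torch
import torch.nn as nn

class RunningStats(nn.Module):
    """Hypothetical module with one persistent and one non-persistent buffer."""

    def __init__(self, num_features: int) -> None:
        super().__init__()
        # Saved in state_dict, like BatchNorm's running_mean.
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Kept on the module and converted with it, but not saved.
        self.register_buffer("scratch", torch.zeros(num_features), persistent=False)

m = RunningStats(3)
print(list(m.state_dict().keys()))              # ['running_mean']
print([name for name, _ in m.named_buffers()])  # ['running_mean', 'scratch']
print(len(list(m.parameters())))                # 0: buffers are not parameters

m = m.double()            # casts both buffers, persistent or not
print(m.scratch.dtype)    # torch.float64
```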

register_forward_hook(hook, *, prepend=False, with_kwargs=False, always_call=False)

Register a forward hook on the module.

The hook will be called every time after forward() has computed an output.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments are passed only to forward(), not to the hooks. The hook can modify the output. It can modify the input in place, but this has no effect on forward() because the hook is called after forward() has run. The hook should have the following signature:

hook(module, args, output) -> None or modified output

If with_kwargs is True, the forward hook will be passed the kwargs given to the forward function and is expected to return the output, possibly modified. The hook should have the following signature:

hook(module, args, kwargs, output) -> None or modified output

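Parameters

hook (Callable) – The user defined hook to be registered.

prepend (bool) – If True, the provided hook will be fired before all existing forward hooks on this torch.nn.Module. Otherwise, it will be fired after all existing forward hooks on this torch.nn.Module. Note that global forward hooks registered with register_module_forward_hook() will fire before all hooks registered by this method. Default: False

with_kwargs (bool) – If True, the hook will be passed the kwargs given to the forward function. Default: False

always_call (bool) – If True, the hook will be run regardless of whether an exception is raised while calling the Module. Default: False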

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle
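A forward hook that returns a value replaces the module's output, and the returned handle removes the hook again. A minimal sketch; clamp_output is a hypothetical hook name.

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 2)

def clamp_output(module, args, output):
    # Returning a value replaces the module's output.
    return output.clamp(min=0.0)

handle = layer.register_forward_hook(clamp_output)
x = torch.randn(4, 2)
with torch.no_grad():
    hooked = layer(x)
    handle.remove()  # unregister the hook
    plain = layer(x)

print(bool((hooked >= 0).all()))                 # True
print(torch.equal(hooked, plain.clamp(min=0.0)))  # True
```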

register_forward_pre_hook(hook, *, prepend=False, with_kwargs=False)

Register a forward pre-hook on the module.

The hook will be called every time before forward() is invoked.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments are passed only to forward(), not to the hooks. The hook can modify the input: it can return either a tuple or a single modified value, and a single returned value is wrapped into a tuple (unless it is already a tuple). The hook should have the following signature:

hook(module, args) -> None or modified input

If with_kwargs is True, the forward pre-hook will be passed the kwargs given to the forward function, and if the hook modifies the input, it should return both the args and the kwargs. The hook should have the following signature:

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs

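Parameters

hook (Callable) – The user defined hook to be registered.

prepend (bool) – If True, the provided hook will be fired before all existing forward_pre hooks on this torch.nn.Module. Otherwise, it will be fired after all existing forward_pre hooks on this torch.nn.Module. Note that global forward_pre hooks registered with register_module_forward_pre_hook() will fire before all hooks registered by this method. Default: False

with_kwargs (bool) – If True, the hook will be passed the kwargs given to the forward function. Default: False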

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle
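A minimal sketch of a forward pre-hook that rewrites the positional input. The hook (hypothetically named double_input) returns a single tensor, which is wrapped into a 1-tuple before being passed to forward().

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 2, bias=False)

def double_input(module, args):
    # args is the tuple of positional inputs. Returning a single tensor
    # is accepted; it is wrapped into a 1-tuple for forward().
    return args[0] * 2

handle = layer.register_forward_pre_hook(double_input)
x = torch.ones(1, 2)
with torch.no_grad():
    hooked = layer(x)
    handle.remove()
    expected = layer(x * 2)  # what the hook effectively fed to forward()

print(torch.equal(hooked, expected))  # True
```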

register_full_backward_hook(hook, prepend=False)

Register a backward hook on the module.

The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module’s forward function.

Warning

Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error.

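Parameters

hook (Callable) – The user-defined hook to be registered.

prepend (bool) – If True, the provided hook will be fired before all existing backward hooks on this torch.nn.Module. Otherwise, it will be fired after all existing backward hooks on this torch.nn.Module. Note that global backward hooks registered with register_module_full_backward_hook() will fire before all hooks registered by this method.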

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle
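A minimal sketch of a full backward hook that only inspects its arguments (it must not modify them in place) and returns None, so grad_input is left unchanged. grad_input has one entry per positional input, grad_output one per output; the hook name record_grads is hypothetical.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
seen = {}

def record_grads(module, grad_input, grad_output):
    # Read-only inspection; returning None keeps grad_input unchanged.
    seen["grad_input"] = [None if g is None else tuple(g.shape) for g in grad_input]
    seen["grad_output"] = [tuple(g.shape) for g in grad_output]

handle = layer.register_full_backward_hook(record_grads)
x = torch.randn(5, 3, requires_grad=True)
layer(x).sum().backward()
handle.remove()

print(seen["grad_input"])   # [(5, 3)]: gradient w.r.t. the input x
print(seen["grad_output"])  # [(5, 2)]: gradient w.r.t. the output
```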

register_full_backward_pre_hook(hook, prepend=False)

Register a backward pre-hook on the module.

The hook will be called every time the gradients for the module are computed. The hook should have the following signature:

hook(module, grad_output) -> tuple[Tensor] or None

The grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of grad_output in subsequent computations. Entries in grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module’s forward function.

Warning

Modifying inputs inplace is not allowed when using backward hooks and will raise an error.

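Parameters

hook (Callable) – The user-defined hook to be registered.

prepend (bool) – If True, the provided hook will be fired before all existing backward_pre hooks on this torch.nn.Module. Otherwise, it will be fired after all existing backward_pre hooks on this torch.nn.Module. Note that global backward_pre hooks registered with register_module_full_backward_pre_hook() will fire before all hooks registered by this method.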

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module’s load_state_dict() is called.

It should have the following signature:

hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys.

The given incompatible_keys can be modified inplace if needed.

Note that the checks performed when calling load_state_dict() with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle
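Because the hook can edit incompatible_keys in place, it can relax the strict=True check. A sketch under a hypothetical policy (tolerate unknown keys such as a legacy "legacy_step" entry, but keep the missing-key check): the hook empties unexpected_keys, so the strict load succeeds.

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 2)

def ignore_extra_keys(module, incompatible_keys):
    # Hypothetical policy: tolerate unexpected keys, keep missing-key checks.
    # Must modify in place and return None.
    incompatible_keys.unexpected_keys.clear()

handle = model.register_load_state_dict_post_hook(ignore_extra_keys)

checkpoint = nn.Linear(2, 2).state_dict()
checkpoint["legacy_step"] = torch.tensor(10)  # key the model does not expect

# strict=True does not raise, because the hook emptied unexpected_keys.
result = model.load_state_dict(checkpoint, strict=True)
print(result.unexpected_keys)  # []
handle.remove()
```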

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module’s load_state_dict() is called.

It should have the following signature:

hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None

Parameters

hook (Callable) – Callable hook that will be invoked before loading the state dict.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

The parameter can be accessed as an attribute using given name.

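Parameters

name (str) – name of the parameter. The parameter can be accessed from this module using the given name.

param (Parameter or None) – parameter to be added to the module. If None, then operations that run on parameters, such as cuda(), are ignored, and the parameter is not included in the module’s state_dict.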

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

It should have the following signature:

hook(module, state_dict, prefix, local_metadata) -> None

The registered hooks can modify the state_dict inplace.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

It should have the following signature:

hook(module, prefix, keep_vars) -> None

The registered hooks can be used to perform pre-processing before the state_dict call is made.

requires_grad_(requires_grad=True)

Change whether autograd should record operations on parameters in this module.

This method sets the parameters’ requires_grad attributes in-place.

This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).

See Locally disabling gradient computation for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it.

Parameters

requires_grad (bool) – whether autograd should record operations on parameters in this module. Default: True.

Returns

self

Return type

Module
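A fine-tuning sketch: requires_grad_(False) on one submodule freezes its parameters, so only the rest of the model receives gradients and is handed to the optimizer. The two-layer architecture and the learning rate are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze the first layer, e.g. a pretrained feature extractor.
model[0].requires_grad_(False)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['2.weight', '2.bias']

# Only the parameters that still require grad are given to the optimizer.
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
model(torch.randn(3, 4)).sum().backward()
optimizer.step()

print(model[0].weight.grad is None)  # True: frozen parameters get no gradient
```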

set_extra_state(state)

Set extra state contained in the loaded state_dict.

This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.

Parameters

state (dict) – Extra state from the state_dict

set_submodule(target, module, strict=False)

Set the submodule given by target if it exists, otherwise throw an error.

Note

If strict is set to False (default), the method will replace an existing submodule or create a new submodule if the parent module exists. If strict is set to True, the method will only attempt to replace an existing submodule and throw an error if the submodule does not exist.

For example, let’s say you have an nn.Module A that looks like this:

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(3, 3, 3)
        )
        (linear): Linear(3, 3)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)

To override the Conv2d with a new submodule Linear, you could call set_submodule("net_b.net_c.conv", nn.Linear(1, 1)), where strict could be True or False.

To add a new submodule Conv2d to the existing net_b module, you would call set_submodule("net_b.conv", nn.Conv2d(1, 1, 1)).

In the above, if you set strict=True and call set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True), an AttributeError will be raised because net_b does not have a submodule named conv.

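Parameters

target (str) – The fully-qualified string name of the submodule to set. (See get_submodule for how to specify a fully-qualified string.)

module (Module) – The module to set the submodule to.

strict (bool) – If False, the method will replace an existing submodule or create a new submodule if the parent module exists. If True, the method will only attempt to replace an existing submodule and throw an error if the submodule does not already exist.

Raises

ValueError – If the target string is empty or if module is not an instance of nn.Module.

AttributeError – If at any point along the path resulting from the target string the (sub)path resolves to a non-existent attribute name or an object that is not an instance of nn.Module.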

share_memory()

See torch.Tensor.share_memory_().

Return type

T

state_dict(*, destination: T_destination, prefix: str = '', keep_vars: bool = False) → T_destination

state_dict(*, prefix: str = '', keep_vars: bool = False) → dict[str, Any]

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

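Parameters

destination (dict, optional) – If provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.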

Returns

a dictionary containing a whole state of the module

Return type

dict

Example:

>>> module.state_dict().keys()
['bias', 'weight']
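The "shallow copy" note has a practical consequence: by default the entries are detached tensors that share storage with the module, so in-place edits to an entry change the module itself. A minimal sketch; cloning the entries is one way to take an independent snapshot.

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 2)
sd = layer.state_dict()

# By default the entries are detached tensors sharing the module's storage.
print(sd["weight"].requires_grad)                          # False
print(sd["weight"].data_ptr() == layer.weight.data_ptr())  # True

# keep_vars=True returns the Parameter objects themselves.
print(layer.state_dict(keep_vars=True)["weight"] is layer.weight)  # True

# Because storage is shared, an in-place edit of an entry changes the module.
sd["weight"].zero_()
print(bool((layer.weight == 0).all()))  # True

# clone() the entries when an independent snapshot is needed.
snapshot = {k: v.clone() for k, v in layer.state_dict().items()}
```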

to(device: Optional[Union[str, device, int]] = ..., dtype: Optional[dtype] = ..., non_blocking: bool = ...) → Self

to(dtype: dtype, non_blocking: bool = ...) → Self

to(tensor: Tensor, non_blocking: bool = ...) → Self

Move and/or cast the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)

to(dtype, non_blocking=False)

to(tensor, non_blocking=False)

to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.

Note

This method modifies the module in-place.

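Parameters

device (torch.device) – the desired device of the parameters and buffers in this module

dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module

tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)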

Returns

self

Return type

Module

Examples:

>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
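The dtype rules above can be checked on a small module that mixes a floating point parameter with an integer buffer. A minimal sketch with a hypothetical WithCounter class: to() casts only the floating point tensors, modifies the module in place and returns it, and rejects integral target dtypes.

```python
import torch
import torch.nn as nn

class WithCounter(nn.Module):
    """Hypothetical module with a float parameter and an integer buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(2))
        self.register_buffer("num_batches", torch.tensor(0, dtype=torch.long))

m = WithCounter()
same = m.to(torch.float64)  # modifies m in place and returns it
print(same is m)            # True
print(m.weight.dtype)       # torch.float64: floating point tensors are cast
print(m.num_batches.dtype)  # torch.int64: integral tensors keep their dtype

int_error = None
try:
    m.to(torch.int32)  # only floating point or complex dtypes are accepted
except TypeError as exc:
    int_error = exc
```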

to_empty(*, device, recurse=True)

Move the parameters and buffers to the specified device without copying storage.

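Parameters

device (torch.device) – The desired device of the parameters and buffers in this module.

recurse (bool) – Whether parameters and buffers of submodules should be recursively moved to the specified device.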

Returns

self

Return type

Module
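One common use of to_empty() is deferred initialization. A module built on the "meta" device has shapes and dtypes but no storage. to_empty() allocates uninitialized storage on a real device, after which the values still have to be initialized. A sketch, assuming the layer's own reset_parameters() is an acceptable initializer.

```python
import torch
import torch.nn as nn

# Construct on the "meta" device: no storage is allocated yet.
layer = nn.Linear(16, 16, device="meta")
print(layer.weight.is_meta)  # True

# Allocate uninitialized storage on the CPU; the values are arbitrary here.
layer.to_empty(device="cpu")
print(layer.weight.device)  # cpu

# Re-run the layer's own initialization to obtain usable values.
layer.reset_parameters()
```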

train(mode=True)

Set the module in training mode.

This has an effect only on certain modules, such as Dropout and BatchNorm. See the documentation of particular modules for details of their behavior in training/evaluation mode, i.e., whether they are affected.

Parameters

mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.

Returns

self

Return type

Module

type(dst_type)

Casts all parameters and buffers to dst_type.

Note

This method modifies the module in-place.

Parameters

dst_type (type or string) – the desired type

Returns

self

Return type

Module

xpu(device=None)

Move all model parameters and buffers to the XPU.

This also makes the associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on the XPU while being optimized.

Note

This method modifies the module in-place.

Parameters

device (int, optional) – if specified, all parameters will be copied to that device

Returns

self

Return type

Module

zero_grad(set_to_none=True)

Reset gradients of all model parameters.

See similar function under torch.optim.Optimizer for more context.

Parameters

set_to_none (bool) – instead of setting to zero, set the grads to None. See torch.optim.Optimizer.zero_grad() for details.
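A minimal sketch of the two modes of zero_grad(): the default set_to_none=True replaces each .grad with None, while set_to_none=False keeps the gradient tensors and fills them with zeros.

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 1)
layer(torch.ones(3, 2)).sum().backward()
print(layer.weight.grad is not None)  # True

layer.zero_grad()  # default set_to_none=True
print(layer.weight.grad)  # None

layer(torch.ones(3, 2)).sum().backward()
layer.zero_grad(set_to_none=False)  # keep the tensors, fill them with zeros
print(layer.weight.grad)  # tensor([[0., 0.]])
```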