torch.func.functional_call — PyTorch 2.7 documentation

torch.func.functional_call(module, parameter_and_buffer_dicts, args=None, kwargs=None, *, tie_weights=True, strict=False)[source]

Performs a functional call on the module by replacing the module parameters and buffers with the provided ones.
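For orientation, here is a minimal, runnable sketch of the call pattern (assuming an nn.Linear, which has no buffers; with the default strict=False, entries missing from the dict fall back to the module's own):

    import torch
    import torch.nn as nn
    from torch.func import functional_call

    model = nn.Linear(3, 3)
    x = torch.randn(2, 3)

    # Call the module with a zeroed-out copy of its parameters; the module's
    # own parameters are left untouched.
    params = {k: torch.zeros_like(v) for k, v in model.named_parameters()}
    out = functional_call(model, params, x)
    print(out)           # all zeros, since weight and bias were replaced
    print(model.weight)  # unchanged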

Note

If the module has active parametrizations, passing a value in the parameter_and_buffer_dicts argument with the name set to the regular parameter name will completely disable the parametrization. If you want to apply the parametrization function to the value passed, set the key to {submodule_name}.parametrizations.{parameter_name}.original.
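For example, here is a sketch of both behaviors using a simple symmetric parametrization (the Symmetric module is illustrative, not part of the original docs; the bypass behavior is as described in the note above):

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize
    from torch.func import functional_call

    class Symmetric(nn.Module):
        def forward(self, X):
            # Rebuild a symmetric matrix from the upper triangle.
            return X.triu() + X.triu(1).transpose(-1, -2)

    lin = nn.Linear(3, 3, bias=False)
    parametrize.register_parametrization(lin, "weight", Symmetric())

    x = torch.randn(2, 3)
    value = torch.randn(3, 3)

    # Key 'weight' bypasses the parametrization: value is used as-is.
    out_raw = functional_call(lin, {"weight": value}, x)

    # This key routes value through Symmetric() before it is used as the weight.
    out_sym = functional_call(lin, {"parametrizations.weight.original": value}, x)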

Note

If the module performs in-place operations on parameters/buffers, these will be reflected in the parameter_and_buffer_dicts input.

Example:

    a = {'foo': torch.zeros(())}
    mod = Foo()  # does self.foo = self.foo + 1
    print(mod.foo)  # tensor(0.)
    functional_call(mod, a, torch.ones(()))
    print(mod.foo)  # tensor(0.)
    print(a['foo'])  # tensor(1.)
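Foo is not defined in the original example; a hypothetical definition consistent with the printed values might look like the following (note the update must be genuinely in-place for it to show up in the passed-in dict):

    class Foo(nn.Module):  # hypothetical, for illustration only
        def __init__(self):
            super().__init__()
            self.register_buffer('foo', torch.zeros(()))

        def forward(self, x):
            self.foo += 1  # in-place, so the change lands in the passed-in tensor
            return x + self.foo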

Note

If the module has tied weights, whether or not functional_call respects the tying is determined by the tie_weights flag.

Example:

    a = {'foo': torch.zeros(())}
    mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
    print(mod.foo)  # tensor(1.)
    mod(torch.zeros(()))  # tensor(2.)
    functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
    functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
    new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
    functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)
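Again, Foo is left undefined above; a hypothetical tied-weight module matching the comments could be:

    class Foo(nn.Module):  # hypothetical, for illustration only
        def __init__(self):
            super().__init__()
            self.foo = nn.Parameter(torch.ones(()))
            self.foo_tied = self.foo  # tied: both names refer to one tensor

        def forward(self, x):
            return x + self.foo + self.foo_tied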

An example of passing multiple dictionaries:

    a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)})  # two separate dictionaries
    mod = nn.Bar(1, 1)  # return self.weight @ x + self.buffer
    print(mod.weight)  # tensor(...)
    print(mod.buffer)  # tensor(...)
    x = torch.randn((1, 1))
    print(x)
    functional_call(mod, a, x)  # same as x
    print(mod.weight)  # same as before functional_call
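nn.Bar is a placeholder, not a real torch.nn module; a runnable equivalent might be:

    import torch
    import torch.nn as nn
    from torch.func import functional_call

    class Bar(nn.Module):  # hypothetical stand-in for nn.Bar above
        def __init__(self, in_features, out_features):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(out_features, in_features))
            self.register_buffer('buffer', torch.randn(out_features))

        def forward(self, x):
            return self.weight @ x + self.buffer

    mod = Bar(1, 1)
    a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)})
    x = torch.randn(1, 1)
    out = functional_call(mod, a, x)  # the dicts are merged; out equals x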

And here is an example of applying the grad transform over the parameters of a model.

    import torch
    import torch.nn as nn
    from torch.func import functional_call, grad

    x = torch.randn(4, 3)
    t = torch.randn(4, 3)
    model = nn.Linear(3, 3)

    def compute_loss(params, x, t):
        y = functional_call(model, params, x)
        return nn.functional.mse_loss(y, t)

    grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)
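grad returns gradients in the same structure as its input, so grad_weights is a dict with the keys 'weight' and 'bias'. As a sketch (not part of the original example), a manual gradient step could use it directly:

    # Hypothetical manual SGD step using the dict returned by grad.
    lr = 0.1
    params = {k: v.detach() for k, v in model.named_parameters()}
    new_params = {k: p - lr * grad_weights[k] for k, p in params.items()}

    # The loss under the updated parameters should typically be lower.
    print(compute_loss(params, x, t), compute_loss(new_params, x, t))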

Note

If the user does not need grad tracking outside of grad transforms, they can detach all of the parameters for better performance and memory usage.

Example:

    detached_params = {k: v.detach() for k, v in model.named_parameters()}
    grad_weights = grad(compute_loss)(detached_params, x, t)
    grad_weights['weight'].grad_fn  # None--it's not tracking gradients outside of grad

This means that the user cannot call grad_weights['weight'].backward(). However, if they don't need autograd tracking outside of the transforms, this will result in less memory usage and faster speeds.

Parameters

module (torch.nn.Module) – the module to call

parameter_and_buffer_dicts (Dict[str, Tensor] or tuple of Dict[str, Tensor]) – the parameters that will be used in the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can be used together.

args (Any or tuple) – arguments to be passed to the module call. If not a tuple, considered a single argument.

kwargs (dict) – keyword arguments to be passed to the module call.

tie_weights (bool, optional) – If True, then parameters and buffers tied in the original model will be treated as tied in the reparametrized version. Therefore, if True and different values are passed for the tied parameters and buffers, it will error. If False, it will not respect the originally tied parameters and buffers unless the values passed for both weights are the same. Default: True.

strict (bool, optional) – If True, then the parameters and buffers passed in must match the parameters and buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will error (see the sketch below). Default: False.
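For instance, strict=True rejects an incomplete dict (a sketch; the exact error type and message are not specified here):

    # nn.Linear has both 'weight' and 'bias'; providing only 'weight' with
    # strict=True should raise an error about the missing key.
    model = nn.Linear(3, 3)
    try:
        functional_call(model, {'weight': torch.ones(3, 3)},
                        torch.randn(2, 3), strict=True)
    except Exception as e:  # hedged: exact exception class may differ
        print(e)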

Returns

the result of calling module.

Return type

Any