torch.nn.utils.stateless.functional_call - PyTorch 2.7 documentation

torch.nn.utils.stateless.functional_call(module, parameters_and_buffers, args=None, kwargs=None, *, tie_weights=True, strict=False)

Perform a functional call on the module by replacing the module parameters and buffers with the provided ones.
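A minimal sketch of the basic call, assuming a toy nn.Linear(3, 1) model and a hypothetical replacements dict keyed by the names that model.named_parameters() reports. The forward pass runs with the replacement tensors, while the module keeps its own weights. Because this API is deprecated, calling it may also print a deprecation warning.

```python
import torch
import torch.nn as nn
from torch.nn.utils.stateless import functional_call

torch.manual_seed(0)
model = nn.Linear(3, 1)
x = torch.randn(4, 3)
w_before = model.weight.detach().clone()

# Replacement tensors, keyed by parameter name ("weight", "bias").
replacements = {
    "weight": torch.ones(1, 3),
    "bias": torch.zeros(1),
}

# Forward pass uses the replacements instead of model.weight / model.bias.
out = functional_call(model, replacements, (x,))

# With all-ones weight and zero bias, each output is the row sum of x.
print(torch.allclose(out, x.sum(dim=1, keepdim=True)))  # True

# The module's own parameters are unchanged after the call.
print(torch.equal(model.weight, w_before))  # True
```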

Warning

This API is deprecated as of PyTorch 2.0 and will be removed in a future version of PyTorch. Please use torch.func.functional_call() instead, which is a drop-in replacement for this API.
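A sketch of a typical migration, assuming a toy linear model, random data and a mean-squared-error loss (all illustrative). The same (module, dict, args) call is made through torch.func.functional_call and wrapped in torch.func.grad, which returns gradients as a dict keyed like the input; the result is cross-checked against ordinary autograd.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad

torch.manual_seed(0)
model = nn.Linear(3, 1)
x = torch.randn(8, 3)
y = torch.randn(8, 1)

# Detached copies of the current parameters, keyed by name.
params = {name: p.detach().clone() for name, p in model.named_parameters()}


def loss_fn(params, x, y):
    # Same argument layout as the deprecated API: (module, dict, args).
    pred = functional_call(model, params, (x,))
    return ((pred - y) ** 2).mean()


# grad differentiates loss_fn with respect to its first argument (the dict).
grads = grad(loss_fn)(params, x, y)

# Cross-check against regular autograd on the module itself.
model.zero_grad()
((model(x) - y) ** 2).mean().backward()
print(torch.allclose(grads["weight"], model.weight.grad))  # True
print(torch.allclose(grads["bias"], model.bias.grad))  # True
```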

Note

If the module has active parametrizations, passing a value in the parameters_and_buffers argument under the regular parameter name completely disables the parametrization. To have the parametrization function applied to the value you pass, set the key to {submodule_name}.parametrizations.{parameter_name}.original instead (without the {submodule_name}. prefix when the parametrization is registered on the top-level module itself).
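A sketch of the .original key, assuming a hypothetical Double parametrization (effective weight is twice the stored tensor) registered on a submodule named proj. It uses torch.func.functional_call, the recommended replacement, which accepts the same dict. Passing the identity under the .original key yields an effective weight of 2 * I, showing that the parametrization is applied to the supplied value.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from torch.func import functional_call


class Double(nn.Module):
    # Hypothetical parametrization: effective weight = 2 * stored tensor.
    def forward(self, X):
        return 2 * X


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(2, 2, bias=False)

    def forward(self, x):
        return self.proj(x)


model = Model()
parametrize.register_parametrization(model.proj, "weight", Double())

# The stored tensor now lives under proj.parametrizations.weight.original.
names = [name for name, _ in model.named_parameters()]
print(names)  # ['proj.parametrizations.weight.original']

x = torch.tensor([[1.0, 2.0]])
key = "proj.parametrizations.weight.original"
out = functional_call(model, {key: torch.eye(2)}, (x,))

# Double() is applied to the value passed in, so the effective weight is 2 * I.
print(out)  # tensor([[2., 4.]])
```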

Note

If the module performs in-place operations on parameters/buffers, these will be reflected in the parameters_and_buffers input.

Example:

```python
>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # does self.foo = self.foo + 1
>>> print(mod.foo)  # tensor(0.)
>>> functional_call(mod, a, torch.ones(()))
>>> print(mod.foo)  # tensor(0.)
>>> print(a['foo'])  # tensor(1.)
```
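A self-contained variant of the example above, assuming a hypothetical Counter module whose forward increments its buffer with a true in-place add_. It uses torch.func.functional_call, the recommended replacement. The module's own buffer stays at zero, while the tensor stored in the dict is the one that gets incremented.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


class Counter(nn.Module):
    # Hypothetical module that bumps a buffer in place on every call.
    def __init__(self):
        super().__init__()
        self.register_buffer("foo", torch.zeros(()))

    def forward(self, x):
        self.foo.add_(1)  # mutates whatever tensor is currently bound to "foo"
        return x + self.foo


mod = Counter()
a = {"foo": torch.zeros(())}
out = functional_call(mod, a, (torch.ones(()),))

print(out)       # tensor(2.)
print(mod.foo)   # tensor(0.): the module's own buffer is untouched
print(a["foo"])  # tensor(1.): the in-place update landed in the dict's tensor
```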

Note

If the module has tied weights, whether or not functional_call respects the tying is determined by the tie_weights flag.

Example:

```python
>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
>>> print(mod.foo)  # tensor(1.)
>>> mod(torch.zeros(()))  # tensor(2.)
>>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.); self.foo_tied is not updated
>>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros(()), tie_weights=False)  # tensor(0.); both names replaced explicitly
```
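A runnable sketch of the tied-weights behavior, assuming Foo ties its weights by assigning the same nn.Parameter to two attribute names (one way to create tying; the original Foo is not shown). It uses torch.func.functional_call, whose tie_weights flag defaults to True as here. The case that passes both tied names uses tie_weights=False, so each name simply receives its own value.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = nn.Parameter(torch.ones(()))
        self.foo_tied = self.foo  # same Parameter registered under two names

    def forward(self, x):
        return x + self.foo + self.foo_tied


mod = Foo()
x = torch.zeros(())
a = {"foo": torch.zeros(())}

r_default = mod(x)                                           # 0 + 1 + 1
r_tied = functional_call(mod, a, (x,))                       # foo_tied follows foo: 0 + 0 + 0
r_untied = functional_call(mod, a, (x,), tie_weights=False)  # foo_tied keeps 1: 0 + 0 + 1
both = {"foo": torch.zeros(()), "foo_tied": torch.zeros(())}
r_both = functional_call(mod, both, (x,), tie_weights=False)  # both replaced: 0 + 0 + 0

print([r.item() for r in (r_default, r_tied, r_untied, r_both)])  # [2.0, 0.0, 1.0, 0.0]
```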

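Parameters

module (torch.nn.Module): the module to call.

parameters_and_buffers (dict of str and Tensor): the parameters and buffers to use in the module call, keyed by their fully qualified names in module (the names reported by module.named_parameters() and module.named_buffers()).

args (Any or tuple): positional arguments passed to the module call. If args is not a tuple, it is treated as a single argument.

kwargs (dict): keyword arguments passed to the module call.

tie_weights (bool, optional): if True, parameters and buffers that are tied in the original module are also treated as tied in the reparametrized call, so a value passed for one name replaces the others too; passing different values for tied names raises an error. If False, the original tying is not respected. Default: True.

strict (bool, optional): if True, the keys of parameters_and_buffers must match the module's parameters and buffers; any missing or unexpected key raises an error. Default: False.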

Returns

the result of calling module.

Return type

Any
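A sketch of how args, kwargs, strict and the return value fit together, assuming a hypothetical Scale module with a keyword-only shift option and a dict output. It uses torch.func.functional_call, the recommended replacement with the same arguments. Whatever forward returns comes back unchanged, which is why the return type is Any.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


class Scale(nn.Module):
    # Hypothetical module: one parameter, a keyword-only option, and a dict output.
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.ones(()))

    def forward(self, x, *, shift=0.0):
        return {"out": self.w * x + shift}


mod = Scale()
params = {"w": torch.tensor(3.0)}

# A non-tuple args value is treated as a single positional argument;
# kwargs is forwarded to forward() as keyword arguments.
res = functional_call(mod, params, torch.tensor(2.0), kwargs={"shift": 1.0})
print(res["out"])  # tensor(7.)

# strict=True rejects a dict whose keys do not match the module's parameters and buffers.
strict_error = None
try:
    functional_call(
        mod,
        {"w": torch.tensor(3.0), "extra": torch.zeros(())},
        torch.tensor(2.0),
        strict=True,
    )
except RuntimeError as err:
    strict_error = err
print(strict_error is not None)  # True
```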