torch.jit.fork — PyTorch 2.7 documentation

torch.jit.fork(func, *args, **kwargs)

Create an asynchronous task executing func and a reference to the value of the result of this execution.

fork will return immediately, so the return value of func may not have been computed yet. To force completion of the task and access the return value, invoke torch.jit.wait on the Future. fork invoked with a func which returns T is typed as torch.jit.Future[T]. fork calls can be arbitrarily nested, and may be invoked with positional and keyword arguments. Asynchronous execution will only occur when run in TorchScript. If run in pure Python, fork will not execute in parallel. fork will also not execute in parallel when invoked while tracing; however, the fork and wait calls will be captured in the exported IR Graph.

Warning

fork tasks will execute non-deterministically. We recommend only spawning parallel fork tasks for pure functions that do not modify their inputs, module attributes, or global state.

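Parameters

func (callable or torch.nn.Module) – A Python function or torch.nn.Module that will be invoked. If executed in TorchScript, it will execute asynchronously; otherwise it will not. Traced invocations of fork will be captured in the IR.

*args, **kwargs – arguments to invoke func with.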

Returns

a reference to the execution of func. The value T can only be accessed by forcing completion of func through torch.jit.wait.

Return type

torch.jit.Future[T]

Example (fork a free function):

import torch
from torch import Tensor

def foo(a: Tensor, b: int) -> Tensor:
    return a + b

def bar(a):
    fut: torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
    return torch.jit.wait(fut)

script_bar = torch.jit.script(bar)
input = torch.tensor(2)

# only the scripted version executes asynchronously
assert script_bar(input) == bar(input)

# trace is not run asynchronously, but fork is captured in IR
graph = torch.jit.trace(bar, (input,)).graph
assert "fork" in str(graph)

Example (fork a module method):

import torch
from torch import Tensor

class AddMod(torch.nn.Module):
    def forward(self, a: Tensor, b: int):
        return a + b

class Mod(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mod = AddMod()

    def forward(self, input):
        # fork the submodule call, passing this method's own argument
        fut = torch.jit.fork(self.mod, input, b=2)
        return torch.jit.wait(fut)

input = torch.tensor(2)
mod = Mod()
# eager and scripted modules should produce the same result
assert mod(input) == torch.jit.script(mod).forward(input)