torch.cuda.make_graphed_callables — PyTorch 2.7 documentation
torch.cuda.make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None)[source]
Accepts callables (functions or nn.Modules) and returns graphed versions of them.
Each graphed callable’s forward pass runs its source callable’s forward CUDA work as a CUDA graph inside a single autograd node.
The graphed callable’s forward pass also appends a backward node to the autograd graph. During backward, this node runs the callable’s backward work as a CUDA graph.
Therefore, each graphed callable should be a drop-in replacement for its source callable in an autograd-enabled training loop.
See Partial-network capture for detailed use and constraints.
If you pass a tuple of several callables, their captures will use the same memory pool. See Graph memory management for when this is appropriate.
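A minimal sketch of graphing a single module for partial-network capture; the module, shapes, and sample input below are illustrative assumptions, not part of the API:

```python
import torch

# Illustrative module and shapes (not prescribed by the API).
module = torch.nn.Linear(4096, 2048).cuda()

# Sample input used only for warmup and capture. Its requires_grad state
# must match the real inputs this callable will see in the training loop.
sample_input = torch.randn(64, 4096, device="cuda", requires_grad=True)

# Returns a graphed drop-in replacement for `module`.
graphed_module = torch.cuda.make_graphed_callables(module, (sample_input,))

# Use it like the original module inside an autograd-enabled loop.
real_input = torch.randn(64, 4096, device="cuda", requires_grad=True)
out = graphed_module(real_input)
out.sum().backward()
```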
Parameters
- callables (torch.nn.Module or Python function, or tuple of these) – Callable or callables to graph. See Graph memory management for when passing a tuple of callables is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order they’ll run in the live workload (see the sketch after this list).
- sample_args (tuple of Tensors, or tuple of tuples of Tensors) – Sample args for each callable. If a single callable was passed, sample_args must be a single tuple of argument Tensors. If a tuple of callables was passed, sample_args must be a tuple of tuples of argument Tensors.
- num_warmup_iters (int) – The number of warmup iterations. Currently, DistributedDataParallel needs 11 iterations for warm up. Default: 3.
- allow_unused_input (bool) – If False, specifying inputs that were not used when computing outputs (and therefore whose grad is always zero) is an error. Defaults to False.
- pool (optional) – Token (returned by graph_pool_handle() or other_Graph_instance.pool()) that hints this graph may share memory with the indicated pool. See Graph memory management.
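A hedged sketch of passing a tuple of callables so their captures share one memory pool; the modules, shapes, and variable names are assumptions for illustration:

```python
import torch

# Two modules that run back-to-back in the live workload (illustrative).
module1 = torch.nn.Linear(4096, 2048).cuda()
module2 = torch.nn.Linear(2048, 1024).cuda()

# Sample inputs, one tuple per callable, with requires_grad matching
# the real inputs each callable will see.
x = torch.randn(64, 4096, device="cuda")
h = torch.randn(64, 2048, device="cuda", requires_grad=True)

# The tuple order must match the order the callables run in the workload,
# and sample_args is a tuple of tuples, one per callable.
module1, module2 = torch.cuda.make_graphed_callables(
    (module1, module2), ((x,), (h,))
)
```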
Note
The requires_grad state of each Tensor in sample_args must match the state that’s expected for the corresponding real input in the training loop.
Warning
This API is in beta and may change in future releases.
Warning
sample_args for each callable must contain only Tensors. Other types are not allowed.
Warning
Returned callables do not support higher order differentiation (e.g., double backward).
Warning
In any Module passed to make_graphed_callables(), only parameters may be trainable. Buffers must have requires_grad=False.
Warning
After you pass a torch.nn.Module through make_graphed_callables(), you may not add or remove any of that Module’s parameters or buffers.
Warning
torch.nn.Modules passed to make_graphed_callables() must not have module hooks registered on them at the time they are passed. However, registering hooks on modules after passing them through make_graphed_callables() is allowed.
Warning
When running a graphed callable, you must pass its arguments in the same order and format they appeared in that callable’s sample_args.
Warning
Automatic mixed precision is supported in make_graphed_callables() only with caching disabled: the context manager torch.cuda.amp.autocast() must be used with cache_enabled=False.
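A hedged sketch of combining a graphed callable with autocast, assuming both the graphing call and the subsequent runs happen inside autocast regions with cache_enabled=False; the module and shapes are illustrative:

```python
import torch

module = torch.nn.Linear(4096, 2048).cuda()  # illustrative module
sample_input = torch.randn(64, 4096, device="cuda")

# Autocast caching is disabled both while graphing and while running
# the graphed callable.
with torch.cuda.amp.autocast(cache_enabled=False):
    graphed_module = torch.cuda.make_graphed_callables(module, (sample_input,))

with torch.cuda.amp.autocast(cache_enabled=False):
    out = graphed_module(torch.randn(64, 4096, device="cuda"))
```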