torch.cuda.graphs — PyTorch 2.7 documentation

# mypy: allow-untyped-defs

import gc
import typing

import torch

from .._utils import _dummy_type

if not hasattr(torch._C, "_CudaStreamBase"):
    # Define dummy base classes
    torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
    torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle")
    torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type(
        "_cuda_isCurrentStreamCapturing"
    )

from torch._C import (  # noqa: F401
    _cuda_isCurrentStreamCapturing,
    _CUDAGraph,
    _graph_pool_handle,
)

[docs]def is_current_stream_capturing():
    r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.

    If a CUDA context does not exist on the current device, returns False without initializing the context.
    """
    return _cuda_isCurrentStreamCapturing()
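
# Editor's illustrative sketch (not part of this module; the helper name is ours).
# It shows that the flag above is True only between capture_begin() and capture_end(),
# and only when queried from the stream doing the capturing. Assumes a CUDA device.
def _example_is_current_stream_capturing():
    g = torch.cuda.CUDAGraph()
    s = torch.cuda.Stream()
    x = torch.zeros(8, device="cuda")
    torch.cuda.synchronize()
    with torch.cuda.stream(s):
        g.capture_begin()
        assert is_current_stream_capturing()
        x += 1  # CUDA work recorded (not executed) during capture
        g.capture_end()
    assert not is_current_stream_capturing()
    g.replay()  # executes the captured increment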

# Python shim helps Sphinx process docstrings more reliably.

[docs]def graph_pool_handle():
    r"""Return an opaque token representing the id of a graph memory pool.

    See :ref:`Graph memory management<graph-memory-management>`.

    .. warning::
        This API is in beta and may change in future releases.
    """
    return _graph_pool_handle()
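
# Editor's illustrative sketch (not part of this module; the helper name is ours).
# Two captures share one memory pool by passing the same handle. Replays must then
# run in the same order as the captures so one graph's tensors aren't clobbered by
# the other. Assumes a CUDA device.
def _example_shared_pool():
    pool = graph_pool_handle()
    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
    a = torch.zeros(8, device="cuda")
    with torch.cuda.graph(g1, pool=pool):
        b = a + 1
    with torch.cuda.graph(g2, pool=pool):
        c = b * 2
    a.fill_(3.0)
    g1.replay()  # computes b = a + 1 into b's captured storage
    g2.replay()  # computes c = b * 2 into c's captured storage
    return c     # each element equals (3 + 1) * 2 = 8 once the replays finish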

# Python shim helps Sphinx process docstrings more reliably.

[docs]class CUDAGraph(torch._C._CUDAGraph):
    r"""Wrapper around a CUDA graph.

.. warning::
    This API is in beta and may change in future releases.
"""

def __new__(cls):
    return super().__new__(cls)

[docs] def capture_begin(self, pool=None, capture_error_mode="global"):
    r"""Begin capturing CUDA work on the current stream.

    Typically, you shouldn't call ``capture_begin`` yourself.
    Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
    which call ``capture_begin`` internally.

    Arguments:
        pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
            with the indicated pool.  See :ref:`Graph memory management<graph-memory-management>`.
        capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
            Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
            may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
            actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
            unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
    """  # noqa: B950
    super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)

[docs] def capture_end(self):
    r"""End CUDA graph capture on the current stream.

    After ``capture_end``, ``replay`` may be called on this instance.

    Typically, you shouldn't call ``capture_end`` yourself.
    Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
    which call ``capture_end`` internally.
    """
    super().capture_end()

[docs] def replay(self):
    r"""Replay the CUDA work captured by this graph."""
    super().replay()

[docs] def reset(self):
    r"""Delete the graph currently held by this instance."""
    super().reset()

[docs] def pool(self):
    r"""Return an opaque token representing the id of this graph's memory pool.

    This id can optionally be passed to another graph's ``capture_begin``,
    which hints the other graph may share the same memory pool.
    """
    return super().pool()

[docs] def enable_debug_mode(self):
    r"""Enable debugging mode for CUDAGraph.debug_dump."""
    return super().enable_debug_mode()

[docs] def debug_dump(self, debug_path):
    r"""
    Arguments:
        debug_path (required): Path to dump the graph to.

    Calls a debugging function to dump the graph if debugging has been
    enabled via CUDAGraph.enable_debug_mode().
    """
    return super().debug_dump(debug_path)
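
# Editor's illustrative sketch (not part of this module; the helper name is ours).
# The low-level capture_begin()/capture_end()/replay() pattern: warm up on a side
# stream, capture on a side stream, then feed new data by copying it into the
# captured input's storage before each replay. Assumes a CUDA device.
def _example_manual_capture():
    static_input = torch.zeros(32, device="cuda")
    g = torch.cuda.CUDAGraph()
    s = torch.cuda.Stream()

    # warmup, so lazy CUDA initialization doesn't end up in the capture
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_output = static_input * 2 + 1
    torch.cuda.current_stream().wait_stream(s)

    torch.cuda.synchronize()
    with torch.cuda.stream(s):
        g.capture_begin()
        static_output = static_input * 2 + 1
        g.capture_end()
    torch.cuda.current_stream().wait_stream(s)

    for step in range(3):
        static_input.fill_(float(step))
        g.replay()                      # recomputes static_output in place
        print(static_output[0].item())  # 1.0, 3.0, 5.0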

[docs]class graph:
    r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay.

See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
detailed use, and constraints.

Arguments:
    cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
    pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
        :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
        may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
    stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
        If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
    capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
        Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
        may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
        actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
        unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_

.. note::
    For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
    used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.

.. warning::
    This API is in beta and may change in future releases.

.. _cudaStreamCaptureMode:
    https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
"""  # noqa: B950

default_capture_stream: typing.Optional["torch.cuda.Stream"] = None

def __init__(
    self,
    cuda_graph,
    pool=None,
    stream=None,
    capture_error_mode: str = "global",
):
    # Lazy-init of default_capture_stream helps avoid circular-import errors.
    # Not thread safe, but graphs already have the general (explicitly documented)
    # restriction that only one capture may be underway at a time in the process.
    if self.__class__.default_capture_stream is None:
        self.__class__.default_capture_stream = torch.cuda.Stream()

    self.pool = () if pool is None else (pool,)
    self.capture_stream = (
        stream if stream is not None else self.__class__.default_capture_stream
    )
    assert self.capture_stream is not None
    self.stream_ctx = torch.cuda.stream(self.capture_stream)
    self.cuda_graph = cuda_graph
    self.capture_error_mode = capture_error_mode

def __enter__(self):
    # Free as much memory as we can for the graph
    torch.cuda.synchronize()
    gc.collect()
    torch.cuda.empty_cache()

    # Stackoverflow seems comfortable with this pattern
    # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
    self.stream_ctx.__enter__()

    self.cuda_graph.capture_begin(
        *self.pool, capture_error_mode=self.capture_error_mode
    )

def __exit__(self, exc_type, exc_value, traceback):
    self.cuda_graph.capture_end()
    self.stream_ctx.__exit__(exc_type, exc_value, traceback)
    # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
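
# Editor's illustrative sketch (not part of this module; the helper name is ours).
# Typical use of the context manager above: warm up on a side stream first so lazy
# initialization (cuDNN benchmarking, etc.) stays out of the capture, then capture a
# forward pass and replay it with fresh input data. Assumes a CUDA device.
def _example_graph_context_manager():
    model = torch.nn.Linear(64, 64).cuda()
    static_input = torch.randn(16, 64, device="cuda")

    with torch.no_grad():
        # warmup on a side stream
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                model(static_input)
        torch.cuda.current_stream().wait_stream(s)

        # capture: `graph` synchronizes, enters a capture stream, and calls
        # capture_begin()/capture_end() around the body
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_output = model(static_input)

    # replay: refill the captured input, then rerun the whole captured forward
    static_input.copy_(torch.randn(16, 64, device="cuda"))
    g.replay()
    return static_output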

[docs]def make_graphed_callables(
    callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
):
    r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and return graphed versions.

Each graphed callable's forward pass runs its source callable's
forward CUDA work as a CUDA graph inside a single autograd node.

The graphed callable's forward pass also appends
a backward node to the autograd graph. During backward, this node runs the
callable's backward work as a CUDA graph.

Therefore, each graphed callable should be a drop-in replacement for its source callable
in an autograd-enabled training loop.

See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.

If you pass a tuple of several callables, their captures will use the same memory pool.
See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.

Arguments:
    callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
        See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
        is appropriate.  If you pass a tuple of callables, their order in the tuple must be the same order
        they'll run in the live workload.
    sample_args (tuple of Tensors, or tuple of tuples of Tensors): Sample args for each callable.
        If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
        If a tuple of callables was passed, ``sample_args`` must be a tuple of tuples of argument Tensors.
    num_warmup_iters (int): The number of warmup iterations. Currently, ``DistributedDataParallel`` needs
        11 iterations for warm up. Default: ``3``.
    allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
        (and therefore their grad is always zero) is an error. Defaults to False.
    pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
        :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
        with the indicated pool.  See :ref:`Graph memory management<graph-memory-management>`.

.. note::
    The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
    that's expected for the corresponding real input in the training loop.

.. warning::
    This API is in beta and may change in future releases.

.. warning::
    ``sample_args`` for each callable must contain only Tensors. Other types are not allowed.

.. warning::
    Returned callables do not support higher order differentiation (e.g., double backward).

.. warning::
    In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
    may be trainable. Buffers must have ``requires_grad=False``.

.. warning::
    After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
    you may not add or remove any of that Module's parameters or buffers.

.. warning::
    :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
    registered on them at the time they are passed. However, registering hooks on modules *after* passing them
    through :func:`~torch.cuda.make_graphed_callables` is allowed.

.. warning::
    When running a graphed callable, you must pass its arguments in the same order and format
    they appeared in that callable's ``sample_args``.

.. warning::
    Automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with caching
    disabled. The context manager `torch.cuda.amp.autocast()` must be used with `cache_enabled=False`.
"""
if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
    raise RuntimeError(
        "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
    )

just_one_callable = False

if not isinstance(callables, tuple):
    just_one_callable = True
    callables = (callables,)
    sample_args = (sample_args,)

flatten_sample_args = []

for c, args in zip(callables, sample_args):
    if isinstance(c, torch.nn.Module):
        assert (
            len(c._backward_hooks) == 0
            and len(c._forward_hooks) == 0
            and len(c._forward_pre_hooks) == 0
        ), (
            "Modules must not have hooks registered at the time they are passed. However, registering hooks "
            + "on modules after passing them through make_graphed_callables is allowed."
        )
        assert all(b.requires_grad is False for b in c.buffers()), (
            "In any :class:`~torch.nn.Module` passed to "
            + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
            + "``requires_grad=False``."
        )
    flatten_arg = torch.utils._pytree.arg_tree_leaves(*args)
    flatten_sample_args.append(tuple(flatten_arg))
    assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
        "In the beta API, sample_args "
        + "for each callable must contain only Tensors. Other types are not allowed."
    )

# If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
# passes to forward (ie, its sample_args) AND the module's parameter attributes.
per_callable_len_user_args = [len(args) for args in flatten_sample_args]
per_callable_module_params = [
    tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
    for c in callables
]
per_callable_static_input_surfaces = [
    flatten_sample_args[i] + per_callable_module_params[i]
    for i in range(len(callables))
]

fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]

mempool = graph_pool_handle() if pool is None else pool

# Warmup
# Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
# from ending up in any captures.
torch.cuda.synchronize()
with torch.cuda.stream(torch.cuda.Stream()):
    for func, args, static_input_surface in zip(
        callables, sample_args, per_callable_static_input_surfaces
    ):
        grad_inputs, outputs, outputs_grad = None, None, None
        for _ in range(num_warmup_iters):
            outputs = torch.utils._pytree.tree_leaves(func(*args))
            outputs_grad = tuple(o for o in outputs if o.requires_grad)
            if len(outputs_grad) > 0:
                grad_inputs = torch.autograd.grad(
                    outputs=outputs_grad,
                    inputs=tuple(
                        i for i in static_input_surface if i.requires_grad
                    ),
                    grad_outputs=tuple(
                        torch.empty_like(o) for o in outputs if o.requires_grad
                    ),
                    only_inputs=True,
                    allow_unused=allow_unused_input,
                )
        for v in [outputs, outputs_grad, grad_inputs]:
            del v

torch.cuda.synchronize()

# All captures here share a mempool. To avoid replays corrupting each other's memory,
# the safest approach is to capture all passes in the same order they'll run:
# fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.

# Capture forward graphs
per_callable_static_outputs = []
per_callable_output_unflatten_spec = []
for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
    with torch.cuda.graph(fwd_graph, pool=mempool):
        outputs = func(*args)

    flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
    per_callable_static_outputs.append(tuple(flatten_outputs))
    per_callable_output_unflatten_spec.append(spec)

# Capture backward graphs in reverse order
per_callable_static_grad_outputs = []
per_callable_static_grad_inputs = []
for static_input_surface, static_outputs, bwd_graph in zip(
    reversed(per_callable_static_input_surfaces),
    reversed(per_callable_static_outputs),
    reversed(bwd_graphs),
):
    # For now, assumes all static_outputs require grad
    # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
    static_grad_outputs = tuple(
        torch.empty_like(o) if o.requires_grad else None for o in static_outputs
    )

    outputs_grad = tuple(o for o in static_outputs if o.requires_grad)
    grad_inputs = None
    if len(outputs_grad) > 0:
        with torch.cuda.graph(bwd_graph, pool=mempool):
            grad_inputs = torch.autograd.grad(
                outputs=outputs_grad,
                inputs=tuple(i for i in static_input_surface if i.requires_grad),
                grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                only_inputs=True,
                allow_unused=allow_unused_input,
            )

    # Constructs a tuple suitable for returning from Graphed.backward:
    # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
    # I couldn't think of a slick one-liner for this pattern.
    static_grad_inputs = []
    grad_idx = 0
    for arg in static_input_surface:
        if arg.requires_grad and grad_inputs is not None:
            static_grad_inputs.append(grad_inputs[grad_idx])
            grad_idx += 1
        else:
            static_grad_inputs.append(None)  # type: ignore[arg-type]
    static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]

    per_callable_static_grad_outputs.append(static_grad_outputs)
    per_callable_static_grad_inputs.append(static_grad_inputs)

# Reverses the most recent two lists
per_callable_static_grad_outputs.reverse()
per_callable_static_grad_inputs.reverse()
# Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.

def make_graphed_autograd_function(
    fwd_graph,
    bwd_graph,
    module_params,
    len_user_args,
    output_unflatten_spec,
    static_input_surface,
    static_outputs,
    static_grad_outputs,
    static_grad_inputs,
):
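    # Wraps one fwd/bwd graph pair in a torch.autograd.Function: forward copies any
    # new user args into the graph's static input tensors and replays fwd_graph;
    # backward copies incoming grads into the static grad_output tensors, replays
    # bwd_graph, and returns the (detached) static grad_inputs.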
    class Graphed(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *inputs):
            # At this stage, only the user args may (potentially) be new tensors.
            for i in range(len_user_args):
                if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                    static_input_surface[i].copy_(inputs[i])
            fwd_graph.replay()
            assert isinstance(static_outputs, tuple)
            return tuple(o.detach() for o in static_outputs)

        @staticmethod
        @torch.autograd.function.once_differentiable
        def backward(ctx, *grads):
            assert len(grads) == len(static_grad_outputs)
            for g, grad in zip(static_grad_outputs, grads):
                if g is not None:
                    # don't copy if autograd gods have been kind and the
                    # incoming grad is already in the right place
                    if g.data_ptr() != grad.data_ptr():
                        g.copy_(grad)
            bwd_graph.replay()

            # Input args that didn't require grad expect a None gradient.
            assert isinstance(static_grad_inputs, tuple)
            return tuple(
                b.detach() if b is not None else b for b in static_grad_inputs
            )

    def functionalized(*user_args):
        # Runs the autograd function with inputs == all inputs to the graph that might require grad
        # (explicit user args + module parameters)
        # Assumes module params didn't change since capture.
        flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args)
        out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
        return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec)

    return functionalized

# Put together the final graphed callables
ret = []
for i, func in enumerate(callables):
    graphed = make_graphed_autograd_function(
        fwd_graphs[i],
        bwd_graphs[i],
        per_callable_module_params[i],
        per_callable_len_user_args[i],
        per_callable_output_unflatten_spec[i],
        per_callable_static_input_surfaces[i],
        per_callable_static_outputs[i],
        per_callable_static_grad_outputs[i],
        per_callable_static_grad_inputs[i],
    )

    if isinstance(func, torch.nn.Module):

        def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
            def new_fwd(*user_args):
                # If the module's training-or-eval state matches what we graphed,
                # run the graph, otherwise run the original forward method
                if func.training == graph_training_state:
                    return graphed(*user_args)
                else:
                    return orig_fwd(*user_args)

            return new_fwd

        func.forward = make_graphed_forward(func, func.training, graphed, func.forward)  # type: ignore[assignment]
        ret.append(func)
    else:
        ret.append(graphed)

if just_one_callable:
    return ret[0]

return tuple(ret)
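
# Editor's illustrative sketch (not part of this module; the helper name is ours).
# Graphing one module and dropping it into an ordinary training loop. ``sample_args``
# must have the same shapes, dtypes, and requires_grad states as the real inputs.
# Assumes a CUDA device.
def _example_make_graphed_callables():
    model = torch.nn.Sequential(
        torch.nn.Linear(64, 128), torch.nn.ReLU(), torch.nn.Linear(128, 10)
    ).cuda()
    loss_fn = torch.nn.MSELoss()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    sample_input = torch.randn(16, 64, device="cuda")
    graphed_model = make_graphed_callables(model, (sample_input,))

    for _ in range(5):
        x = torch.randn(16, 64, device="cuda")
        target = torch.randn(16, 10, device="cuda")
        opt.zero_grad(set_to_none=True)
        y = graphed_model(x)           # forward replays the captured forward graph
        loss_fn(y, target).backward()  # backward replays the captured backward graph
        opt.step()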