torch.utils.checkpoint — PyTorch 2.7 documentation

Note

Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward propagation. This can cause persistent states like the RNG state to be more advanced than they would be without checkpointing. By default, checkpointing includes logic to juggle the RNG state such that checkpointed passes making use of RNG (through dropout, for example) have deterministic output compared to non-checkpointed passes. The logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of checkpointed operations. If deterministic output compared to non-checkpointed passes is not required, supply preserve_rng_state=False to checkpoint or checkpoint_sequential to omit stashing and restoring the RNG state during each checkpoint.
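For instance, a minimal sketch of opting out of RNG stashing when reproducibility against a non-checkpointed run is not needed (the function and tensor names below are illustrative, not from the documentation):

import torch
from torch.utils.checkpoint import checkpoint

def drop_fn(x):
    # Uses RNG via dropout; its output need not match a non-checkpointed run.
    return torch.nn.functional.dropout(x, p=0.5)

x = torch.randn(4, 4, requires_grad=True)
out = checkpoint(drop_fn, x, use_reentrant=False, preserve_rng_state=False)
out.sum().backward()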

The stashing logic saves and restores the RNG state for the CPU and for one other device type (the device type is inferred from the Tensor arguments to run_fn, excluding CPU tensors, by _infer_device_type). If there are multiple devices, device state is only saved for devices of a single device type, and the remaining devices are ignored. Consequently, if any checkpointed functions involve randomness, this may result in incorrect gradients. (Note that if CUDA devices are among the devices detected, CUDA is prioritized; otherwise, the first device type encountered is selected.) If no non-CPU tensors are found among the Tensor arguments, the default device type state (the default is cuda; it can be set to another device type via DefaultDeviceType) is saved and restored. However, the logic has no way to anticipate whether the user will move Tensors to a new device within run_fn itself. Therefore, if you move Tensors to a new device ("new" meaning not belonging to the set of [current device + devices of Tensor arguments]) within run_fn, deterministic output compared to non-checkpointed passes is never guaranteed.
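As a small, hedged illustration (assuming a non-CUDA accelerator such as an XPU build is the target), the default device type whose RNG state is stashed can be changed through DefaultDeviceType when it cannot be inferred from the Tensor arguments:

from torch.utils.checkpoint import DefaultDeviceType

# Stash/restore RNG state for "xpu" instead of the default "cuda" whenever
# no non-CPU tensor arguments are available to infer the device type from.
DefaultDeviceType.set_device_type("xpu")
print(DefaultDeviceType.get_device_type())  # "xpu"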

torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[source]

Checkpoint a model or part of the model.

Activation checkpointing is a technique that trades compute for memory. Instead of keeping tensors needed for backward alive until they are used in gradient computation during backward, forward computation in checkpointed regions omits saving tensors for backward and recomputes them during the backward pass. Activation checkpointing can be applied to any part of a model.

There are currently two checkpointing implementations available, determined by the use_reentrant parameter. It is recommended that you use use_reentrant=False. Please refer to the note below for a discussion of their differences.
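A minimal usage sketch with the recommended non-reentrant variant (the function and tensors below are illustrative):

import torch
from torch.utils.checkpoint import checkpoint

def run_fn(a, b):
    # Intermediates produced here are not kept alive for backward;
    # they are recomputed when gradients are needed.
    return torch.matmul(a, b).relu()

a = torch.randn(8, 8, requires_grad=True)
b = torch.randn(8, 8, requires_grad=True)
out = checkpoint(run_fn, a, b, use_reentrant=False)
out.sum().backward()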

Warning

If the function invocation during the backward pass differs from the forward pass, e.g., due to a global variable, the checkpointed version may not be equivalent, potentially causing an error to be raised or leading to silently incorrect gradients.
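As a hypothetical sketch of this failure mode (the global scale and the function fn below are illustrative only):

import torch
from torch.utils.checkpoint import checkpoint

scale = 1.0

def fn(x):
    # Reads a global that can change between the original forward pass
    # and the recomputation performed during backward.
    return x * scale

x = torch.randn(4, requires_grad=True)
out = checkpoint(fn, x, use_reentrant=False)
scale = 2.0           # mutated before backward
out.sum().backward()  # recompute sees scale=2.0, so x.grad is silently wrong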

Warning

The use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. If you are using the use_reentrant=True variant, please refer to the note below for important considerations and potential limitations.

Note

The reentrant variant of checkpoint (use_reentrant=True) and the non-reentrant variant of checkpoint (use_reentrant=False) differ in the following ways:

Parameters

Returns

Output of running function on *args

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source]

Checkpoint a sequential model to save memory.

Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model in various segments and checkpoint each segment. All segments except the last will not store the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.

Warning

The use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. If you are using the use_reentrant=True variant, please see checkpoint() for the important considerations and limitations of this variant. It is recommended that you use use_reentrant=False.

Parameters

Returns

Output of running functions sequentially on *inputs

Example

model = nn.Sequential(...)
input_var = checkpoint_sequential(model, chunks, input_var)
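A slightly fuller, hedged sketch (the layer sizes and segment count are illustrative):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
input_var = torch.randn(32, 128, requires_grad=True)
# Divide the five modules into two checkpointed segments.
out = checkpoint_sequential(model, 2, input_var, use_reentrant=False)
out.sum().backward()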

torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source]

Context manager that sets whether checkpoint should print additional debug information when running. See the debug flag for checkpoint() for more information. Note that when set, this context manager overrides the value of debug passed to checkpoint. To defer to the local setting, pass None to this context.

Parameters

enabled (bool) – Whether checkpoint should print debug information. Default is ‘None’.
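A minimal usage sketch (fn and x are illustrative placeholders):

import torch
from torch.utils.checkpoint import checkpoint, set_checkpoint_debug_enabled

def fn(x):
    return x.sin().cos()

x = torch.randn(4, requires_grad=True)
# Enable extra debug information for any checkpoint() calls in this scope.
with set_checkpoint_debug_enabled(True):
    out = checkpoint(fn, x, use_reentrant=False)
    out.sum().backward()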

class torch.utils.checkpoint.CheckpointPolicy(value)[source]

Enum for specifying the policy for checkpointing during backpropagation.

The following policies are supported:

{MUST,PREFER}_SAVE: the op's output is saved during the forward pass and is not recomputed during the backward pass.

{MUST,PREFER}_RECOMPUTE: the op's output is not saved during the forward pass and is recomputed during the backward pass.

Use MUST_* over PREFER_* to indicate that the policy should not be overridden by other subsystems like torch.compile.

Note

A policy function that always returns PREFER_RECOMPUTE is equivalent to vanilla checkpointing.

A policy function that returns PREFER_SAVE for every op is NOT equivalent to not using checkpointing. Using such a policy would save additional tensors, not limited to the ones actually needed for gradient computation.
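As an illustrative sketch of the note above (not part of the original documentation), a policy that always prefers recomputation simply reproduces vanilla checkpointing:

from torch.utils.checkpoint import CheckpointPolicy

def vanilla_policy(ctx, op, *args, **kwargs):
    # Recompute every op during backward, i.e. plain checkpointing behavior.
    return CheckpointPolicy.PREFER_RECOMPUTE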

class torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[source]

Context passed to policy function during selective checkpointing.

This class is used to pass relevant metadata to the policy function during selective checkpointing. The metadata includes whether the current invocation of the policy function is during recomputation or not.

Example

def policy_fn(ctx, op, *args, **kwargs):
    print(ctx.is_recompute)

context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)

out = torch.utils.checkpoint.checkpoint(
    fn, x, y, use_reentrant=False, context_fn=context_fn,
)

torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[source]

Helper to avoid recomputing certain ops during activation checkpointing.

Use this with torch.utils.checkpoint.checkpoint to control which operations are recomputed during the backward pass.

Parameters

Returns

A tuple of two context managers.

Example

import functools
import torch
from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

x = torch.rand(10, 10, requires_grad=True)
y = torch.rand(10, 10, requires_grad=True)

ops_to_save = [
    torch.ops.aten.mm.default,
]

def policy_fn(ctx, op, *args, **kwargs):
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)

# or equivalently

context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)

def fn(x, y):
    return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y

out = torch.utils.checkpoint.checkpoint(
    fn, x, y, use_reentrant=False, context_fn=context_fn,
)
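In the example above, torch.matmul on these 2D inputs dispatches to aten.mm, so the outputs of both matrix multiplications are saved during the forward pass and are not recomputed during backward, while the cheaper sigmoid and elementwise multiply remain subject to recomputation.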