(beta) Utilizing Torch Function modes with torch.compile (PyTorch Tutorials 2.7.0+cu126 documentation)
Author: Michael Lazos
This recipe shows how to use torch function modes, a key extensibility point in PyTorch, together with torch.compile to override the behavior of torch operators (also known as ops) at trace time, with no runtime overhead.
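To see what a torch function mode actually intercepts, here is a minimal eager-mode sketch with no torch.compile involved. The class name `RecordingMode` and its `seen` list are illustrative and not part of PyTorch. The sketch relies only on `BaseTorchFunctionMode` from `torch.overrides`, whose `__torch_function__` forwards each call to `func(*args, **kwargs)`.

```python
import torch
from torch.overrides import BaseTorchFunctionMode


class RecordingMode(BaseTorchFunctionMode):
    """Illustrative mode (hypothetical name): records every function it intercepts."""

    def __init__(self):
        super().__init__()
        self.seen = []  # functions passed to __torch_function__, in call order

    def __torch_function__(self, func, types, args=(), kwargs=None):
        self.seen.append(func)
        # BaseTorchFunctionMode.__torch_function__ simply calls func(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)


x = torch.ones(2)
y = torch.ones(2)

mode = RecordingMode()
with mode:
    z = x + y * x        # infix operators arrive as torch.Tensor methods
    w = torch.add(x, y)  # the functional form arrives as torch.add

print(mode.seen)
```

In eager mode this handler runs in Python on every intercepted op call. Tracing with torch.compile is what removes that per-call cost.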
Note
This recipe requires PyTorch 2.7.0 or later.
Rewriting a torch op (torch.add -> torch.mul)
For this example, we’ll use a torch function mode to rewrite occurrences of addition as multiplication. This kind of override is common when a backend has a custom implementation of an op that should be dispatched in place of the default one.
```python
import torch

# Exit cleanly if we are on a CUDA device that doesn't support torch.compile
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

from torch.overrides import BaseTorchFunctionMode

# Define our mode. Note: BaseTorchFunctionMode
# implements the actual invocation of func(...)
class AddToMultiplyMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func == torch.Tensor.add:
            func = torch.mul

        return super().__torch_function__(func, types, args, kwargs)


@torch.compile()
def test_fn(x, y):
    return x + y * x  # Note: infix operators map to torch.Tensor.* methods


x = torch.rand(2, 2)
y = torch.rand_like(x)

with AddToMultiplyMode():
    z = test_fn(x, y)

assert torch.allclose(z, x * y * x)
```
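The mode above matches only `torch.Tensor.add`, which is what the infix `+` operator dispatches to. In eager mode, a functional call such as `torch.add(x, y)` reaches `__torch_function__` with `func` set to `torch.add`, so that mode would let it through unchanged. The sketch below is a variant that rewrites both forms. The class name and the `_ADD_FUNCS` tuple are hypothetical, and the sketch runs eagerly for brevity.

```python
import torch
from torch.overrides import BaseTorchFunctionMode

# The two spellings of addition this sketch rewrites (an assumption: other
# forms such as in-place add_ are deliberately left alone).
_ADD_FUNCS = (torch.Tensor.add, torch.add)


class AddToMultiplyModeAllForms(BaseTorchFunctionMode):
    """Hypothetical variant of AddToMultiplyMode that also catches torch.add(...)."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func in _ADD_FUNCS:
            # Caveat: torch.mul has no `alpha` argument, so a call such as
            # torch.add(x, y, alpha=2) would fail after this rewrite.
            func = torch.mul
        return super().__torch_function__(func, types, args, kwargs)


x = torch.rand(2, 2)
y = torch.rand_like(x)
expected = x * y  # computed before the mode is active

with AddToMultiplyModeAllForms():
    a = x + y            # dispatches torch.Tensor.add
    b = torch.add(x, y)  # dispatches torch.add

assert torch.allclose(a, expected)
assert torch.allclose(b, expected)
```

Matching on a small tuple of function objects keeps the rewrite explicit. Any op not listed falls through to `BaseTorchFunctionMode`, which calls it unchanged.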
The mode can also be used inside the compiled region:
```python
@torch.compile()
def test_fn(x, y):
    with AddToMultiplyMode():
        return x + y * x  # Note: infix operators map to torch.Tensor.* methods


x = torch.rand(2, 2)
y = torch.rand_like(x)
z = test_fn(x, y)

assert torch.allclose(z, x * y * x)
```
Conclusion
In this recipe, we demonstrated how to override the behavior of torch.* operators using torch function modes from within torch.compile. Because the override is applied while the function is traced, users get the extensibility benefits of torch function modes without the runtime overhead of calling __torch_function__ on every op invocation.
- See Extending Torch API with Modes for other examples and background on Torch Function modes.