(beta) Utilizing Torch Function modes with torch.compile - PyTorch Tutorials 2.7.0+cu126 documentation


Author: Michael Lazos

This recipe covers how to use a key torch extensibility point, torch function modes, in tandem with torch.compile to override the behavior of torch operators, also known as ops, at trace time, with no runtime overhead.
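To make the extensibility point concrete, here is a minimal eager-mode sketch of what a torch function mode sees. It does not use torch.compile, and the class name RecordingMode and its seen attribute are hypothetical names introduced only for illustration. The sketch assumes that infix operators reach the mode as torch.Tensor.* methods, as the example later in this recipe also relies on.

```python
import torch
from torch.overrides import TorchFunctionMode


class RecordingMode(TorchFunctionMode):
    """Hypothetical helper: records every function the mode intercepts."""

    def __init__(self):
        super().__init__()
        self.seen = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Record the intercepted function, then run it unchanged.
        self.seen.append(func)
        return func(*args, **(kwargs or {}))


x = torch.ones(2, 2)
y = torch.ones(2, 2)

# Every torch call made inside the block is routed through the mode.
with RecordingMode() as mode:
    z = x + y * x

print(mode.seen)
```

In eager mode this handler runs on every op call, which is the per-call overhead that the torch.compile approach in this recipe is meant to avoid.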

Note

This recipe requires PyTorch 2.7.0 or later.

Rewriting a torch op (torch.add -> torch.mul)

For this example, we'll use torch function modes to rewrite occurrences of addition as multiplication. This type of override is common when a backend has a custom implementation that should be dispatched for a given op.

import torch

# Exit cleanly if we are on a CUDA device that doesn't support torch.compile.
# The is_available() guard is an addition so CPU-only machines don't raise here.
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

from torch.overrides import BaseTorchFunctionMode

# Define our mode. Note: BaseTorchFunctionMode
# implements the actual invocation of func(..)
class AddToMultiplyMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Swap addition for multiplication; every other op passes through.
        if func == torch.Tensor.add:
            func = torch.mul

        return super().__torch_function__(func, types, args, kwargs)

@torch.compile()
def test_fn(x, y):
    return x + y * x  # Note: infix operators map to torch.Tensor.* methods

x = torch.rand(2, 2)
y = torch.rand_like(x)

with AddToMultiplyMode():
    z = test_fn(x, y)

assert torch.allclose(z, x * y * x)
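Because a torch function mode also works without torch.compile, one way to sanity-check the rewrite logic before compiling is to run the same function in eager mode, with and without the mode. The sketch below is illustrative only: eager_fn is a hypothetical name, and the mode is redefined so the block is self-contained.

```python
import torch
from torch.overrides import BaseTorchFunctionMode


class AddToMultiplyMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Same rewrite as above: addition becomes multiplication.
        if func == torch.Tensor.add:
            func = torch.mul
        return super().__torch_function__(func, types, args, kwargs)


def eager_fn(x, y):  # hypothetical name; deliberately not compiled
    return x + y * x


x = torch.rand(2, 2)
y = torch.rand_like(x)

# Inside the mode, the addition is rewritten to a multiplication.
with AddToMultiplyMode():
    rewritten = eager_fn(x, y)

# Outside the mode, the function behaves as written.
plain = eager_fn(x, y)

assert torch.allclose(rewritten, x * y * x)
assert torch.allclose(plain, x + y * x)
```

The eager run pays the handler cost on each op call, so it is useful for checking the rewrite but does not show the zero-overhead behavior of the compiled version.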

The mode can also be used inside the compiled region:

@torch.compile()
def test_fn(x, y):
    with AddToMultiplyMode():
        return x + y * x  # Note: infix operators map to torch.Tensor.* methods

x = torch.rand(2, 2)
y = torch.rand_like(x)
z = test_fn(x, y)

assert torch.allclose(z, x * y * x)
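The single-op rewrite generalizes to the backend use case mentioned earlier, where several ops each have a custom implementation. The following is a rough sketch under stated assumptions, not part of the original recipe: OverrideTableMode, its overrides table, and custom_relu are hypothetical names, and custom_relu merely stands in for a backend-specific kernel. It is shown in eager mode only; behavior under torch.compile is not verified here.

```python
import torch
from torch.overrides import BaseTorchFunctionMode

calls = []  # records which overrides fired, for illustration only


def custom_relu(x):
    # Hypothetical stand-in for a backend-specific kernel.
    calls.append("custom_relu")
    return torch.clamp(x, min=0)


class OverrideTableMode(BaseTorchFunctionMode):
    # Hypothetical lookup table; keys are functions as the mode receives them.
    overrides = {torch.relu: custom_relu}

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Substitute the custom implementation if one is registered.
        func = self.overrides.get(func, func)
        return super().__torch_function__(func, types, args, kwargs)


x = torch.tensor([-1.0, 0.5, 2.0])

with OverrideTableMode():
    out = torch.relu(x)

assert calls == ["custom_relu"]
assert torch.equal(out, torch.tensor([0.0, 0.5, 2.0]))
```

Keys must match the exact function object the mode receives, so torch.relu(x) and the method form x.relu() would need separate table entries.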


Conclusion

In this recipe we demonstrated how to override the behavior of torch.* operators using torch function modes from within torch.compile. This enables users to utilize the extensibility benefits of torch function modes without the runtime overhead of calling torch function on every op invocation.
