
TorchDynamo

TorchDynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. TorchDynamo hooks into the frame evaluation API in CPython (PEP 523) to dynamically modify Python bytecode right before it is executed. It rewrites Python bytecode in order to extract sequences of PyTorch operations into an FX Graph which is then just-in-time compiled with a customizable backend. It creates this FX Graph through bytecode analysis and is designed to mix Python execution with compiled backends to get the best of both worlds: usability and performance.

For more on TorchDynamo you can read our posts on PyTorch dev-discuss or watch a deep-dive video.

This repository also hosts TorchInductor, which is a TorchDynamo backend able to translate an FX Graph into Triton for GPUs or C++/OpenMP for CPUs. We have a training performance dashboard comparing the performance of different training backends. You can read more in the TorchInductor post on PyTorch dev-discuss.

TorchDynamo is experimental and under active development. You are welcome to try it out and contribute, but should expect to find bugs and rough edges.

Requirements and Setup

Python 3.8 is recommended. Python 3.7 through 3.10 are supported and tested.

TorchDynamo requires the latest development/nightly build of PyTorch or building PyTorch from source. TorchDynamo also requires functorch from the PyTorch repository. The Makefile target make setup_nightly_gpu contains the commands used by our CI to set up dependencies.

The 1.12 branch contains an older snapshot of TorchDynamo that works on PyTorch 1.12. However, it is missing the latest features and is not recommended.

Other development requirements can be installed with:

pip install -r requirements.txt

Install TorchDynamo with:

Usage Example

Here is a basic example of how to use TorchDynamo. One can decorate a function or a method using torchdynamo.optimize to enable TorchDynamo optimization.

from typing import List

import torch
import torchdynamo

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable

@torchdynamo.optimize(my_compiler)
def fn(x, y):
    a = torch.cos(x)
    b = torch.sin(y)
    return a + b

fn(torch.randn(10), torch.randn(10))

Running the above example produces this output:

my_compiler() called with FX graph:
opcode         name    target                                                  args        kwargs
-------------  ------  ------------------------------------------------------  ----------  --------
placeholder    x       x                                                       ()          {}
placeholder    y       y                                                       ()          {}
call_function  cos     <built-in method cos of type object at 0x7f1a894649a8>  (x,)        {}
call_function  sin     <built-in method sin of type object at 0x7f1a894649a8>  (y,)        {}
call_function  add     <built-in function add>                                 (cos, sin)  {}
output         output  output                                                  ((add,),)   {}

This works for torch.nn.Module as well, as shown below:

import torch
import torchdynamo

class MockModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

def forward(self, x):
    return self.relu(torch.cos(x))

mod = MockModule()
optimized_mod = torchdynamo.optimize(my_compiler)(mod)
optimized_mod(torch.randn(10))

In the above examples, TorchDynamo uses a custom compiler (also referred to as a backend in the rest of this doc), my_compiler, that just prints the FX GraphModule extracted by TorchDynamo's bytecode analysis and returns its forward callable. One could write new compilers in a similar fashion.
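For example, here is a sketch of another trivial backend (counting_compiler is a hypothetical name, not part of TorchDynamo) that tallies the operator calls in each captured graph and then falls back to eager execution:

from typing import List

import torch

def counting_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # FX nodes have an op kind: placeholder, call_function, call_method,
    # call_module, get_attr, or output. Count the "call_*" ones.
    num_calls = sum(1 for node in gm.graph.nodes if node.op.startswith("call"))
    print(f"captured graph with {num_calls} call nodes")
    return gm.forward  # run the captured graph eagerly, unmodified

It can be passed to torchdynamo.optimize exactly like my_compiler above.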

Let's take a look at one more example with control flow.

from typing import List

import torch
import torchdynamo

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable

@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))

Running this example produces the following output:

my_compiler() called with FX graph:
opcode         name     target                                                  args              kwargs
-------------  -------  ------------------------------------------------------  ----------------  --------
placeholder    a        a                                                       ()                {}
placeholder    b        b                                                       ()                {}
call_function  abs_1    <built-in method abs of type object at 0x7f8d259298a0>  (a,)              {}
call_function  add      <built-in function add>                                 (abs_1, 1)        {}
call_function  truediv  <built-in function truediv>                             (a, add)          {}
call_method    sum_1    sum                                                     (b,)              {}
call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
output         output   output                                                  ((truediv, lt),)  {}

my_compiler() called with FX graph:
opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    b       b                        ()           {}
placeholder    x       x                        ()           {}
call_function  mul     <built-in function mul>  (b, -1)      {}
call_function  mul_1   <built-in function mul>  (x, mul)     {}
output         output  output                   ((mul_1,),)  {}

my_compiler() called with FX graph:
opcode         name    target                   args       kwargs
-------------  ------  -----------------------  ---------  --------
placeholder    b       b                        ()         {}
placeholder    x       x                        ()         {}
call_function  mul     <built-in function mul>  (x, b)     {}
output         output  output                   ((mul,),)  {}

Note that the order of the last two graphs is nondeterministic: it depends on which branch the just-in-time compiler encounters first.

Existing Backends

TorchDynamo has a growing list of backends, which can be found in backends.py or torchdynamo.list_backends(). Note that many backends require installing additional packages. Some of the most commonly used backends are listed below; a sketch for querying the full set at runtime follows the list.

Debugging backends:

Training & inference backends:

Inference-only backends:
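To check which backend names are registered in your install, you can query the list at runtime; a minimal sketch (the exact names printed depend on which optional packages you have installed):

import torchdynamo

# Prints the registered backend names, any of which can be passed
# by name to torchdynamo.optimize("<backend>").
print(torchdynamo.list_backends())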

Training and AotAutograd

TorchDynamo supports training, using AotAutograd to capture the backward pass:

Current limitations:

Example

model = ...
optimizer = ...

@torchdynamo.optimize("inductor") def training_iteration(...): return model(...)

for _ in range(100):
    loss = training_iteration(...)
    loss.backward()
    optimizer.step()

Troubleshooting

See Troubleshooting.

Adding Backends

One could replace my_compiler() in the examples above with something that generates faster code, for example, one using optimize_for_inference:

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    scripted = torch.jit.trace(gm, example_inputs)
    return torch.jit.optimize_for_inference(scripted)

TorchDynamo also includes many backends, which can be found in backends.py or torchdynamo.list_backends(). Note that many backends require installing additional packages. You can combine these backends together with code like:

from torchdynamo.optimizations import BACKENDS

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    trt_compiled = BACKENDS["tensorrt"](gm, example_inputs)
    if trt_compiled is not None:
        return trt_compiled
    # first backend failed, try something else...

cudagraphs_compiled = BACKENDS["cudagraphs"](gm, example_inputs)
if cudagraphs_compiled is not None:
    return cudagraphs_compiled

return gm.forward

Guards

TorchDynamo operates just-in-time and specializes graphs based on dynamic properties. For example, the first graph above has the following guards:

GUARDS:
 - local 'a' TENSOR_MATCH
 - local 'b' TENSOR_MATCH
 - global 'torch' FUNCTION_MATCH

If any of those guards fail, the graph will be recaptured and recompiled. The interesting guard type there is TENSOR_MATCH, which checks the following torch.Tensor properties:

*For sizes/strides you can disable this specialization by setting:

torchdynamo.config.dynamic_shapes = True

The full specialization mode allows the backend compiler to assume an entirely static graph. Unfortunately, most backends require this. Operators which return dynamic shapes will trigger a graph break when not in dynamic shape mode.
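To see a guard failure trigger recompilation, you can call the compiled fn from the usage example with inputs whose guarded properties change; a minimal sketch, assuming dtype is among the properties TENSOR_MATCH checks:

# Reuses fn and my_compiler from the usage example above.
fn(torch.randn(10), torch.randn(10))    # first call: compiles, my_compiler prints the graph
fn(torch.randn(10), torch.randn(10))    # guards pass: cached graph is reused, nothing printed
fn(torch.randn(10, dtype=torch.float64),
   torch.randn(10, dtype=torch.float64))  # TENSOR_MATCH fails: graph is recaptured and recompiled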

Run Mode / Quiescence Guarantee

In some cases, you may not want unexpected compiles after a program has warmed up, for example if you are serving production traffic in a latency-critical application. For this, TorchDynamo provides an alternate mode where prior compiled graphs are used, but no new ones are generated:

frozen_toy_example = torchdynamo.run(toy_example)
frozen_toy_example(torch.randn(10), torch.randn(10))
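Since run only reuses graphs that were already compiled, the usual pattern is to exercise the optimized function on representative inputs first; a sketch reusing toy_example from the control-flow example above:

# Warm up: the @torchdynamo.optimize-decorated toy_example compiles
# graphs for the input shapes we expect to serve.
for _ in range(3):
    toy_example(torch.randn(10), torch.randn(10))

# Serve: prior compiled graphs are reused; no new ones are generated.
frozen_toy_example = torchdynamo.run(toy_example)
frozen_toy_example(torch.randn(10), torch.randn(10))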

Single Whole-Program Graph Mode

In some cases, you may want to ensure there are no graph breaks in your program to debug performance issues. You can turn graph breaks into errors by setting nopython=True:

@torchdynamo.optimize(my_compiler, nopython=True)
def toy_example(a, b):

This will trigger the following error in the example program above:

Traceback (most recent call last):
  ...
torchdynamo.exc.Unsupported: generic_jump TensorVariable()
Processing original code:
  File "example.py", line 7, in toy_example
    if b.sum() < 0:

Developer Setup

As background reading, I'd suggest looking at the PyTorch, functorch, and TorchBench setup docs, since these projects work together in different ways.

The following instructions use Miniconda.

conda create --name=torchdynamo python=3.8
conda activate torchdynamo

# install pytorch nightlies
# for CUDA version, replace `cpuonly` with `cudatoolkit=11.6`
conda install pytorch torchvision torchaudio torchtext cpuonly -c pytorch-nightly
pip install -v "git+https://github.com/pytorch/pytorch.git@`python -c "import torch.version; print(torch.version.git_version)"`#subdirectory=functorch"

git clone git@github.com:pytorch/torchdynamo.git
cd torchdynamo
pip install -r requirements.txt

# check if everything works
make test

If you see errors about missing symbols from guards.so, that may mean your C++ compiler is incompatible with CUDA and/or with the one used to compile PyTorch. You may need to change your compiler version or build PyTorch from source.

Tests

Run tests with

To debug a specific test (with more debug prints) do:

Test on torchbenchmark models with:

python benchmarks/torchbench.py

Linting and Automatic Code Formatting


Install format/linter deps with pip install -r requirements.txt, then:

make format  # reformat all files (in-place)
make lint    # run the linters

License

TorchDynamo has a BSD-style license, as found in the LICENSE file.