TorchDynamo: An Experiment in Dynamic Python Bytecode Transformation
September 24, 2021, 10:46pm
In “Next Steps for PyTorch Compilers,” we laid out a vision of deploying eager-mode PyTorch to more production settings and investing in compilers to make eager mode faster and easier to maintain. This move away from graph mode makes some things a lot harder. For example, simple fusions that cross operator boundaries seem, at first glance, impossible without users modifying their models. Lazy Tensors is one way to recapture these optimization opportunities. However, because it sits below the dispatcher, it cannot remove the overheads from Python and the upper levels of the PyTorch stack, so it may not be a good choice for smaller, overhead-bound models.
TorchDynamo is an early experiment that radically rethinks the approach to recapturing these optimization opportunities. It hooks into the frame evaluation API in CPython (PEP 523) to dynamically modify Python bytecode right before it is executed. This is analogous to what DynamoRIO does by dynamically modifying x86 machine code. TorchDynamo rewrites Python bytecode to extract sequences of PyTorch operations into an FX Graph, which is then just-in-time compiled with a user-defined compiler. It creates this FX Graph through bytecode analysis, not tracing, and is designed to generate smaller graph fragments that can be mixed with Python execution. This approach has many advantages:
- It supports all Python code, because it can easily fall back to running the original bytecode. It depends on a fast eager mode to work properly, because the goal is to enhance eager mode rather than replace it.
- It has extremely low overhead, and it can even remove Python overheads from the original program, because it intercepts execution at the very top of the stack.
- It adds no latency, because it does not defer execution.
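TorchDynamo does its rewriting inside CPython's frame evaluation hook, which is installed through CPython's C API. As a rough, Python-level analogue of "modify the bytecode, then run the modified version," the sketch below builds a new function from an edited copy of another function's code object. `halve` and `replace_constant` are hypothetical names, and this is not how TorchDynamo itself installs its changes.

```python
import types


def halve(x):
    return x / 2.0


def replace_constant(fn, old, new):
    """Return a copy of fn whose code object has constant `old` swapped for `new`.

    Hypothetical helper: a static, Python-level analogue of bytecode rewriting.
    TorchDynamo instead rewrites frames at evaluation time via CPython's C API.
    """
    code = fn.__code__
    new_consts = tuple(
        new if (type(c) is type(old) and c == old) else c for c in code.co_consts
    )
    new_code = code.replace(co_consts=new_consts)  # CodeType.replace, Python 3.8+
    # Wrap the modified code in a new function object that shares fn's globals.
    return types.FunctionType(new_code, fn.__globals__, fn.__name__,
                              fn.__defaults__, fn.__closure__)


quarter = replace_constant(halve, 2.0, 4.0)
print(halve(8.0), quarter(8.0))  # 4.0 2.0
```

The original function is untouched; only the new function object sees the edited constant, which loosely mirrors how the original bytecode stays available as a fallback.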
How TorchDynamo works
The figure above shows how TorchDynamo changes the behavior of CPython. TorchDynamo installs a custom eval frame function that performs dynamic bytecode analysis and transformation. The transformations insert calls to compiled FX Graphs into the bytecode. Reuse of these compiled artifacts is protected by guards to ensure soundness.
To make this process clearer, let’s go through an example. Consider this toy code:
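import torch
import torchdynamo

def custom_compiler(graph: torch.fx.GraphModule):
    return graph.forward  # no optimizations; just run the captured graph

def fn(a, b):
    x = a + b
    x = x / 2.0
    if x.sum() < 0:
        return x * -1.0
    return x

with torchdynamo.optimize(custom_compiler):
    fn(torch.randn(10), torch.randn(10))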
This toy example results in the following original Python bytecode for fn():
0 LOAD_FAST 0 (a)
2 LOAD_FAST 1 (b)
4 BINARY_ADD
6 STORE_FAST 2 (x)
8 LOAD_FAST 2 (x)
10 LOAD_CONST 1 (2.0)
12 BINARY_TRUE_DIVIDE
14 STORE_FAST 2 (x)
16 LOAD_FAST 2 (x)
18 LOAD_METHOD 0 (sum)
20 CALL_METHOD 0
22 LOAD_CONST 2 (0)
24 COMPARE_OP 0 (<)
26 POP_JUMP_IF_FALSE 36
28 LOAD_FAST 2 (x)
30 LOAD_CONST 3 (-1.0)
32 BINARY_MULTIPLY
34 RETURN_VALUE
36 LOAD_FAST 2 (x)
38 RETURN_VALUE
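The listing above can be reproduced with the standard-library `dis` module, as in the minimal sketch below. Opcode names depend on the Python version: the listing above uses pre-3.11 opcodes such as BINARY_ADD and CALL_METHOD, which newer interpreters replace with instructions such as BINARY_OP and CALL, so the exact output on your machine will likely differ.

```python
import dis


def fn(a, b):
    x = a + b
    x = x / 2.0
    if x.sum() < 0:
        return x * -1.0
    return x


# Disassembling only inspects fn's code object; it does not run fn,
# so no tensors (and no PyTorch install) are needed here.
dis.dis(fn)

# The same information, collected programmatically as a list of opcode names.
opnames = [ins.opname for ins in dis.get_instructions(fn)]
```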
TorchDynamo dynamically rewrites that bytecode as follows:
0 LOAD_GLOBAL 1 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 LOAD_FAST 1 (b)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 2
10 STORE_FAST 2 (x)
12 POP_JUMP_IF_FALSE 22
14 LOAD_GLOBAL 2 (__compiled_fn_1)
16 LOAD_FAST 2 (x)
18 CALL_FUNCTION 1
20 RETURN_VALUE
22 LOAD_FAST 2 (x)
24 RETURN_VALUE
The new bytecode calls the two compiled FX graphs shown below. Note how the control flow splits the program into two graphs.
__compiled_fn_0:
opcode name target args kwargs
------------- ------- --------------------------- ---------------- --------
placeholder a_0 a_0 () {}
placeholder b_1 b_1 () {}
call_function add <built-in function add> (a_0, b_1) {}
call_function truediv <built-in function truediv> (add, 2.0) {}
call_method sum_1 sum (truediv,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
__compiled_fn_1:
opcode name target args kwargs
------------- ------ ----------------------- ----------- --------
placeholder x_4 x_4 () {}
call_function mul <built-in function mul> (x_4, -1.0) {}
output output output (mul,) {}
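To make the split concrete, here is a plain-Python sketch of what the rewritten bytecode does at the source level. It is illustrative only: Python lists stand in for tensors so it runs without PyTorch, and `compiled_fn_0`, `compiled_fn_1`, and `fn_transformed` are hypothetical names mirroring __compiled_fn_0, __compiled_fn_1, and the rewritten fn.

```python
# Hypothetical stand-ins: lists of floats play the role of tensors.

def compiled_fn_0(a, b):
    # Mirrors __compiled_fn_0: add, truediv, sum, lt.
    # Returns (truediv, lt), i.e. (x, x.sum() < 0).
    x = [(ai + bi) / 2.0 for ai, bi in zip(a, b)]
    return x, sum(x) < 0


def compiled_fn_1(x):
    # Mirrors __compiled_fn_1: a single mul by -1.0.
    return [xi * -1.0 for xi in x]


def fn_transformed(a, b):
    # Source-level view of the rewritten bytecode: CALL_FUNCTION + UNPACK_SEQUENCE,
    # then the data-dependent branch stays in ordinary Python (POP_JUMP_IF_FALSE).
    x, is_negative = compiled_fn_0(a, b)
    if is_negative:
        return compiled_fn_1(x)
    return x


print(fn_transformed([1.0, 3.0], [1.0, 1.0]))    # [1.0, 2.0]
print(fn_transformed([-4.0, 0.0], [0.0, -2.0]))  # [2.0, 1.0]
```

The branch condition is produced inside the first graph but consumed by Python, which is exactly why the program ends up as two graphs rather than one.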
Finally, TorchDynamo generates two guards:
- Local arg “a” must be a torch.Tensor
- Local arg “b” must be a torch.Tensor
Failure of either of these guards triggers re-analysis and transformation.
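A toy model of guarded reuse follows. It assumes guards are simple per-local type checks, mirroring the two guards listed above; `FakeTensor`, `type_guards`, and `GuardedCache` are hypothetical names, and this is a sketch of the idea rather than TorchDynamo's implementation.

```python
from typing import Any, Callable, Dict, List, Tuple

Guard = Callable[[Dict[str, Any]], bool]


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""


def type_guards(frame_locals: Dict[str, Any]) -> List[Guard]:
    # One guard per local, e.g. 'Local arg "a" must be a FakeTensor'.
    return [
        (lambda f, name=name, cls=type(value): isinstance(f[name], cls))
        for name, value in frame_locals.items()
    ]


class GuardedCache:
    """Hypothetical cache: reuse a compiled artifact only while all its guards pass."""

    def __init__(self, compile_fn: Callable[[Dict[str, Any]], Callable]):
        self.compile_fn = compile_fn
        self.entries: List[Tuple[List[Guard], Callable]] = []
        self.compiles = 0

    def lookup(self, frame_locals: Dict[str, Any]) -> Callable:
        for guards, compiled in self.entries:
            if all(guard(frame_locals) for guard in guards):
                return compiled  # every guard passed: safe to reuse
        # No entry matched (or a guard failed): re-analyze and recompile.
        self.compiles += 1
        compiled = self.compile_fn(frame_locals)
        self.entries.append((type_guards(frame_locals), compiled))
        return compiled


cache = GuardedCache(lambda frame_locals: (lambda: "compiled"))
cache.lookup({"a": FakeTensor(), "b": FakeTensor()})
cache.lookup({"a": FakeTensor(), "b": FakeTensor()})  # reused, no recompile
cache.lookup({"a": 1.0, "b": FakeTensor()})           # guard on "a" fails
print(cache.compiles)  # 2
```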
If TorchDynamo encounters calls to non-PyTorch code or unsupported Python constructs, it leaves them in the original bytecode. In this way, TorchDynamo opportunistically finds optimization opportunities without sacrificing the Python user experience.
Usage
Here is how the current API works:
from typing import Callable

import torch
import torchdynamo

def custom_compiler(graph: torch.fx.GraphModule) -> Callable:
    # do cool compiler optimizations here
    return graph.forward

with torchdynamo.optimize(custom_compiler):
    # any PyTorch code
    # custom_compiler() is called to optimize extracted fragments
    # should reach a fixed point where nothing new is compiled
    ...

# Optionally:
with torchdynamo.run():
    # any PyTorch code
    # previously compiled artifacts are reused
    # provides a quiescence guarantee, without compiles
    ...
You define your compiler function (which compiles an FX graph to a Python callable), then wrap the code you want TorchDynamo to optimize in a torchdynamo.optimize context manager. This should be all you need. In cases where you want to make sure no compilation warm-up cost is added, we provide torchdynamo.run to reuse prior optimizations from torchdynamo.optimize without triggering any new compiles.
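The difference between the two context managers can be modeled in a few lines. This is a hypothetical sketch: `ToyDynamo` and its `call` method are made up for illustration, and it caches per function with no guards. Inside `optimize`, unseen functions are compiled; inside `run`, cached artifacts are reused and everything else falls back to the original code.

```python
import contextlib
from typing import Callable, Dict, Optional


class ToyDynamo:
    """Hypothetical model of the optimize()/run() semantics (no guards, no bytecode)."""

    def __init__(self) -> None:
        self.cache: Dict[Callable, Callable] = {}
        self.compiler: Optional[Callable[[Callable], Callable]] = None
        self.allow_compile = False
        self.compiles = 0

    @contextlib.contextmanager
    def optimize(self, compiler: Callable[[Callable], Callable]):
        # New fragments may be compiled with the user's compiler.
        self.compiler, self.allow_compile = compiler, True
        try:
            yield
        finally:
            self.compiler, self.allow_compile = None, False

    @contextlib.contextmanager
    def run(self):
        # Reuse previously compiled artifacts, but never compile anything new.
        previous, self.allow_compile = self.allow_compile, False
        try:
            yield
        finally:
            self.allow_compile = previous

    def call(self, fn: Callable, *args):
        if fn in self.cache:
            return self.cache[fn](*args)
        if self.allow_compile and self.compiler is not None:
            self.compiles += 1
            self.cache[fn] = self.compiler(fn)
            return self.cache[fn](*args)
        return fn(*args)  # fall back to the original, uncompiled code


def double(x):
    return 2 * x


def triple(x):
    return 3 * x


dynamo = ToyDynamo()
with dynamo.optimize(lambda fn: fn):  # identity "compiler", like returning graph.forward
    dynamo.call(double, 3)
    dynamo.call(double, 4)            # already compiled: reused
with dynamo.run():
    dynamo.call(double, 5)            # reuses the cached artifact
    dynamo.call(triple, 5)            # not cached: runs original code, no compile
print(dynamo.compiles)  # 1
```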
Early results
This project is still very early, so we haven’t tried applying optimizations yet and have been focusing on correctness, overhead, and coverage. We measured 35 TorchBench models using Python 3.8 on an Intel CPU. Raw results are here.
To summarize the results in the key focus areas:
- Correctness: 100%
- Correctness is by far the most important metric. It measures how many models run and produce the right answer. The goal here is to make zero sacrifices to user experience and to sit at the maximum-usability end of the Pareto-optimal curve. This is still early work, so there are surely bugs and gaps, though running on TorchBench gives some confidence that it works for a wide variety of models.
- Overhead: <1% average
- Checking guards and patching frame objects adds some overhead. On the measured models, overhead is under 1% for most models, and TorchDynamo actually speeds many models up slightly. This is without doing any optimizations in the FX compiler function, so we pay all the costs but get none of the benefits; this metric is therefore the worst-case scenario. A later focus will be using TorchDynamo to apply optimizations and get speedups.
- Coverage: 60% of ops, 64% of time
- The final metric is how many ops TorchDynamo captures relative to the total ops in the whole model. This early version captures 60% of all ops (which account for 64% of execution time). On some models 0% is captured, on others 100%, and most models are somewhere in between. There are still many missing features to add, so coverage is the current main area of focus for improvement.
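The overhead and coverage metrics above are simple ratios. Below is a minimal sketch of one way to compute them, assuming per-op records of whether each op was captured and how long it took. The functions and toy numbers are hypothetical; they are not the TorchBench results or the project's benchmark code.

```python
from typing import List, Tuple


def overhead_pct(eager_seconds: float, dynamo_seconds: float) -> float:
    # Positive: TorchDynamo is slower than plain eager mode; negative: a speedup.
    return 100.0 * (dynamo_seconds - eager_seconds) / eager_seconds


def coverage(ops: List[Tuple[bool, float]]) -> Tuple[float, float]:
    """ops holds hypothetical (captured, seconds) records, one per executed op.

    Returns (% of ops captured, % of op time captured).
    """
    captured_times = [t for was_captured, t in ops if was_captured]
    op_pct = 100.0 * len(captured_times) / len(ops)
    time_pct = 100.0 * sum(captured_times) / sum(t for _, t in ops)
    return op_pct, time_pct


# Toy numbers for illustration only, not measurements.
toy_ops = [(True, 2.0), (True, 1.0), (False, 1.0), (True, 4.0), (False, 2.0)]
print(coverage(toy_ops))  # (60.0, 70.0)
print(overhead_pct(10.0, 10.05))
```

The two coverage numbers differ whenever captured ops are, on average, more or less expensive than uncaptured ones, which is why op share and time share are reported separately.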
Next Steps
Check out the source code here. This is still an experimental prototype, so use it at your own risk. If you want to contribute, please reach out to us.
There is still a ton of work left to do, so stay tuned for future updates that (hopefully) include higher coverage and some applications resulting in speedups!
Does that mean TorchDynamo only supports inference scenarios? What is the roadmap for TorchDynamo to support PyTorch compilers for training?
jansel January 4, 2022, 2:58am
Support for training is planned and a key priority for TorchDynamo. It will be based on the (also experimental) compilation / AOT autograd work found in the functorch repo.
hgt312 January 19, 2022, 8:21am
Hi! I am interested in the training support. Will the training support include the optimizer? If so, will the weight-update logic be separate, or will everything (fwd+bwd+optim) be in a single graph?
jansel January 19, 2022, 8:57pm
Training support is not finished yet. I don’t see anything about optimizers that would prevent them from being captured with minor tweaks.
Regarding single whole-program graphs, TorchDynamo often generates single graphs, but there is no guarantee you will get a whole-program graph, and that is not the goal. The design philosophy is mixed-mode execution that works with Python and prioritizes preserving the usability of PyTorch. Tons of things will result in graph breaks, including: using Python types (e.g., Tensor.item, Tensor.tolist, torch.any); calling external C libraries (e.g., numpy); printing/logging; control flow (e.g., early stopping in a training loop); constructing custom Python classes; and more. If you absolutely require whole-program graphs above all else, a different approach, such as AOT tracing or Lazy Tensors, might be a better fit.
elvinagam November 15, 2022, 4:49am
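To illustrate why graph breaks produce several smaller graphs interleaved with Python, the sketch below groups a linear trace of op names into capturable and non-capturable runs. It is a deliberate simplification: TorchDynamo analyzes bytecode rather than a flat op list, `CAPTURABLE` and `split_into_segments` are hypothetical, and "item" and "print" stand in for the Tensor.item and printing cases mentioned above.

```python
from typing import List, Tuple

# Hypothetical set of ops a toy capturer is able to place into a graph.
CAPTURABLE = {"add", "div", "mul", "relu", "sum"}


def split_into_segments(ops: List[str]) -> List[Tuple[str, List[str]]]:
    """Group a linear op trace into ('graph', [...]) and ('python', [...]) runs."""
    segments: List[Tuple[str, List[str]]] = []
    for op in ops:
        kind = "graph" if op in CAPTURABLE else "python"
        if segments and segments[-1][0] == kind:
            segments[-1][1].append(op)  # extend the current run
        else:
            segments.append((kind, [op]))  # a graph break (or resume) starts a new run
    return segments


trace = ["add", "div", "item", "mul", "relu", "print", "sum"]
for kind, run in split_into_segments(trace):
    print(kind, run)
# graph ['add', 'div']
# python ['item']
# graph ['mul', 'relu']
# python ['print']
# graph ['sum']
```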
Just checking out the attached Google Docs file, which lists the benchmarked models. It looks like those architectures are not really getting any significant speedups with TorchDynamo? Or is there an updated list of speedups? @jansel
jansel November 15, 2022, 9:04am
Nice job!
I’m interested in the model quantization part. Can this be applied to model quantization? Quantization with torch.fx still needs symbolic tracing to acquire the graph, which dynamic control flow may block. From what little I know about TorchDynamo, it seems to work only at runtime. I hope I have stated my question clearly.