TorchDynamo: An Experiment in Dynamic Python Bytecode Transformation

September 24, 2021, 10:46pm 1

TorchDynamo

In Next Steps for PyTorch Compilers, we laid out a vision of deploying eager mode PyTorch to more production settings and investing in using compilers to make eager mode faster and easier to maintain. This move away from graph mode makes some things a lot harder. For example, simple fusions that cross operator boundaries are at first glance not possible without users modifying their models. Lazy Tensors is one way to recapture these optimization opportunities. However, because it exists below the dispatcher, it cannot remove the overheads from Python and the upper levels of the PyTorch stack, so it may not be a good choice for smaller, overhead-bound models.

TorchDynamo is an early experiment that radically rethinks the approach for recapturing these optimization opportunities. It hooks into the frame evaluation API in CPython to dynamically modify Python bytecode right before it is executed. This is analogous to what DynamoRIO does by dynamically modifying x86 machine code. TorchDynamo dynamically rewrites Python bytecode in order to extract sequences of PyTorch operations into an FX Graph, which is then just-in-time compiled with a user-defined compiler. It creates this FX Graph through bytecode analysis, not tracing, and is designed to generate smaller graph fragments that can be mixed with Python execution. This approach has many advantages.

How TorchDynamo works

The figure above shows how TorchDynamo changes the behavior of CPython. TorchDynamo installs a custom eval frame function that performs dynamic bytecode analysis and transformation. The transformations insert calls to compiled FX Graphs into the bytecode, and reuse of these compiled artifacts is protected by guards to ensure soundness.

To make this process clearer, let's go through an example. Consider this toy code:

def fn(a, b):
    x = a + b
    x = x / 2.0
    if x.sum() < 0:
        return x * -1.0
    return x
 
with torchdynamo.optimize(custom_compiler):
    fn(torch.randn(10), torch.randn(10))

This toy example results in the following original Python bytecode for fn():

 0  LOAD_FAST 0 (a)
 2  LOAD_FAST 1 (b)
 4  BINARY_ADD
 6  STORE_FAST 2 (x)

 8  LOAD_FAST 2 (x)
 10 LOAD_CONST 1 (2.0)
 12 BINARY_TRUE_DIVIDE
 14 STORE_FAST 2 (x)

 16 LOAD_FAST 2 (x)
 18 LOAD_METHOD 0 (sum)
 20 CALL_METHOD 0
 22 LOAD_CONST 2 (0)
 24 COMPARE_OP 0 (<)
 26 POP_JUMP_IF_FALSE 36

 28 LOAD_FAST 2 (x)
 30 LOAD_CONST 3 (-1.0)
 32 BINARY_MULTIPLY
 34 RETURN_VALUE

 36 LOAD_FAST 2 (x)
 38 RETURN_VALUE
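
For reference, a listing like the one above can be reproduced with Python's built-in dis module (the exact opcodes depend on the Python version; the results later in this post use Python 3.8):

import dis

def fn(a, b):
    x = a + b
    x = x / 2.0
    if x.sum() < 0:
        return x * -1.0
    return x

dis.dis(fn)  # prints the original, untransformed bytecode of fn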

TorchDynamo dynamically rewrites that bytecode as follows:

 0  LOAD_GLOBAL 1 (__compiled_fn_0)
 2  LOAD_FAST 0 (a)
 4  LOAD_FAST 1 (b)
 6  CALL_FUNCTION 2
 8  UNPACK_SEQUENCE 2
 10 STORE_FAST 2 (x)
 12 POP_JUMP_IF_FALSE 22
 
 14 LOAD_GLOBAL 2 (__compiled_fn_1)
 16 LOAD_FAST 2 (x)
 18 CALL_FUNCTION 1
 20 RETURN_VALUE

 22 LOAD_FAST 2 (x)
 24 RETURN_VALUE

This new bytecode calls the two compiled FX graphs shown below. Note how the data-dependent control flow splits the program into two graphs.

__compiled_fn_0:
opcode         name     target                       args              kwargs
-------------  -------  ---------------------------  ----------------  --------
placeholder    a_0      a_0                          ()                {}
placeholder    b_1      b_1                          ()                {}
call_function  add      <built-in function add>      (a_0, b_1)        {}
call_function  truediv  <built-in function truediv>  (add, 2.0)        {}
call_method    sum_1    sum                          (truediv,)        {}
call_function  lt       <built-in function lt>       (sum_1, 0)        {}
output         output   output                       ((truediv, lt),)  {}

__compiled_fn_1:
opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    x_4     x_4                      ()           {}
call_function  mul     <built-in function mul>  (x_4, -1.0)  {}
output         output  output                   (mul,)       {}
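
Reading the rewritten bytecode back as Python source makes the split easier to see. The transformed function behaves roughly like the sketch below; this is a hand-written illustration (TorchDynamo emits bytecode directly, not source), where __compiled_fn_0 and __compiled_fn_1 are the globals installed by the transformation:

def fn_transformed(a, b):
    # __compiled_fn_0 fuses the add, divide, sum, and compare into one graph
    # and returns the (truediv, lt) pair from its output node
    x, took_branch = __compiled_fn_0(a, b)
    if took_branch:                  # the POP_JUMP_IF_FALSE on the lt result
        return __compiled_fn_1(x)    # the x * -1.0 branch
    return x                         # the fall-through branch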

Finally, TorchDynamo generates two guards for this example; failure of either guard triggers re-analysis and transformation.
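
The exact guard contents are not reproduced here, but conceptually a guard is a cheap predicate that must pass before a compiled artifact is reused. A purely hypothetical sketch for this example (real guards are generated internally by TorchDynamo and may check more properties) could look like:

import torch

def guards_hold(a, b):
    # Hypothetical illustration: check that the captured locals still look
    # like they did during analysis before reusing the compiled graphs.
    return isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)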

If TorchDynamo encounters calls to non-PyTorch code or Python constructs it cannot handle, it leaves those in the original bytecode. Thus, TorchDynamo opportunistically finds optimization opportunities without sacrificing the Python user experience.
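
For instance, mixing PyTorch ops with a numpy call would produce multiple graph fragments, with the numpy portion left to run as ordinary Python (an illustrative sketch, not verified output):

import numpy as np
import torch

def mixed(a, b):
    x = a + b                                      # capturable into an FX graph
    scale = float(np.log1p(x.abs().sum().item()))  # .item() + numpy: stays in Python
    return x * scale                               # capturable into a second graph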

Usage

Here is how the current API works:

def custom_compiler(graph: torch.fx.GraphModule) -> Callable:
    # do cool compiler optimizations here
    return graph.forward
    
with torchdynamo.optimize(custom_compiler):
    # any PyTorch code
    # custom_compiler() is called to optimize extracted fragments
    # should reach a fixed point where nothing new is compiled
    
# Optionally:
with torchdynamo.run():
    # any PyTorch code
    # previously compiled artifacts are reused
    # provides a quiescence guarantee (no new compiles)

You define your compiler function (which compiles an FX graph to a Python callable), then wrap the code you want TorchDynamo to optimize in a torchdynamo.optimize context manager. This should be all you need. In cases where you want to make sure there is no added compile warm-up overhead, we provide torchdynamo.run, which reuses prior optimizations from torchdynamo.optimize but does not trigger any new compiles.
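
Putting the pieces together, a minimal end-to-end sketch might look like the following (my_compiler is a hypothetical name; it just prints the captured graph and returns it unmodified, using the single-argument compiler signature shown above):

import torch
import torchdynamo

def my_compiler(gm: torch.fx.GraphModule):
    print(gm.graph)       # inspect the captured FX graph
    return gm.forward     # no optimization: run the captured graph as-is

def fn(a, b):
    x = (a + b) / 2.0
    return x * -1.0 if x.sum() < 0 else x

with torchdynamo.optimize(my_compiler):
    fn(torch.randn(10), torch.randn(10))  # analysis and compilation happen here

with torchdynamo.run():
    fn(torch.randn(10), torch.randn(10))  # reuses prior artifacts, no new compiles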

Early results

This project is still very early, so we haven't tried applying optimizations yet and have been focusing on correctness, overhead, and coverage. We measured 35 TorchBench models using Python 3.8 on an Intel CPU. Raw results are here.

To summarize the results in the key focus areas:

Next Steps

Check out the source code here. This is still an experimental prototype, so use at your own risk. If you want to contribute, please reach out to us.

There is still a ton of work left to do, so stay tuned for future updates that (hopefully) include higher coverage and some applications resulting in speedups!

Does that mean TorchDynamo only supports inference scenarios? What is the roadmap for TorchDynamo to support PyTorch compilers for training?

jansel January 4, 2022, 2:58am 3

Support for training is planned and a key priority for TorchDynamo. It will be based on the (also experimental) compilation / AOT autograd work found in the functorch repo.

hgt312 January 19, 2022, 8:21am 4

Hi! I am interested in the training support. Will the training support include the optimizer? If so, will the weight-update logic be separate, or will everything (fwd+bwd+optim) be in a single graph?

jansel January 19, 2022, 8:57pm 5

Training support is not finished yet. I don’t see anything about optimizers that would prevent them from being captured with minor tweaks.

Regarding single whole-program graphs, TorchDynamo often generates single graphs – but there is no guarantee you will get a whole-program graph, and that is not the goal. The design philosophy is mixed-mode execution that works with Python and prioritizes preserving the usability of PyTorch. Tons of things will result in graph breaks, including: using Python types (e.g. Tensor.item, Tensor.tolist, torch.any, etc.); calling external C libraries (e.g. numpy); printing/logging; control flow (e.g. early stopping in a training loop); constructing custom Python classes; and more. If you absolutely require whole-program graphs above all else, then a different approach, like AOT tracing or Lazy Tensors, might be a better fit.
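
As an illustration of one such graph break (a hypothetical sketch, not a statement about current training support), converting a loss to a Python float for an early-stopping check forces TorchDynamo to hand control back to Python at that point:

import torch

def train_step(model, opt, x, y, loss_threshold=0.01):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
    # loss.item() converts the tensor to a Python float, which forces a graph
    # break; the early-stopping decision then runs in regular Python.
    return loss.item() < loss_threshold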

elvinagam November 15, 2022, 4:49am 6

Just checking out the Google Docs file attached, which lists the benchmarked models. It seems those architectures are not really getting any significant speedups with TorchDynamo? Or is there an updated list of speedups? @jansel

jansel November 15, 2022, 9:04am 7

Nice job!
I'm interested in the model quantization part. Can this feature be applied to model quantization? I mean, quantization with torch.fx still needs symbolic tracing to acquire the graph, which the dynamic control flow problem may block. As far as I know, TorchDynamo only works at runtime. I hope I have clarified my question well. 🤣