Python Operator Authoring w/ NNC

with @bertmaher and @ZolotukhinM

TLDR; we built a prototype that allows defining pointwise operators in Python, backed by an NNC/FX-based JIT. It supports dynamic shapes and demonstrates up to a 3.5x speedup over torch.add on CPU and 1.5x on GPU. It can also be used as an explicit pointwise fusion API, where it beats the performance of the existing TorchScript+NNC fuser by having lower overhead.

Motivation

With torch::deploy launching and the shift towards using PyTorch eager in more production settings, the PyTorch Compiler team has been thinking about how we can use compiler techniques to improve eager mode where we don’t have access to a whole program graph.

One of the ideas we have been thinking about is a new way to define operators in Python and have those operators backed by a JIT compiler. This approach has a number of benefits, along with a few downsides.

User Interface

To help drive discussion, gather data, and make this idea more concrete, we built a working prototype. The prototype allows you to define a new pointwise operator in Python:

@torch.jit.te.pointwise_operator
def add(a, b):
    return a + b

This operator will be JIT compiled and is nearing feature parity with the existing TensorIterator-backed pointwise operators. It currently supports CPU/GPU, broadcasting, type promotion, shape checking, strides, out variants, some aliasing, some backwards/autograd, etc. Where features are missing, it still performs all the relevant checks to properly simulate the cost of implementing them.
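For illustration, here is a hedged sketch of what calling such a decorated operator could look like, exercising the broadcasting, type promotion, and out-variant behavior listed above. The prototype's exact call surface may differ, and the out= keyword is assumed by analogy with existing pointwise ops:

import torch

@torch.jit.te.pointwise_operator
def add(a, b):
    return a + b

# Broadcasting: a (1, 8) tensor broadcasts against an (8, 8) tensor, as with torch.add.
x = torch.rand(1, 8)
y = torch.rand(8, 8)
print(add(x, y).shape)   # torch.Size([8, 8])

# Type promotion: float32 + float64 should promote to float64.
z = torch.rand(8, 8, dtype=torch.float64)
print(add(x, z).dtype)   # torch.float64

# Out variant (assumed to mirror torch.add's out= keyword).
out = torch.empty(8, 8)
add(x, y, out=out)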

You can use this same interface to define custom fused operators, for example:

@torch.jit.te.pointwise_operator
def fused_addnorm(a, b, m, d):
    return (a + b - m) / d

We imagine extending this interface to handle composite operators, reductions, and other more complex types of ops.

Implementation and Specialization

The high-level design is to trace the Python definition of the operator with FX, specialize on the relevant properties of the inputs (captured in the SpecializationKey described below), and compile a kernel with NNC that is cached and reused on subsequent calls. This allows us to bypass much of the overhead and complexity of the existing call path, while still being able to handle the weird corner cases.
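As a rough, hypothetical sketch (not the prototype's actual code), the call path can be thought of as a cache of compiled kernels keyed on a per-input specialization; the helpers and key contents below are deliberately simplified stand-ins:

import torch

_kernel_cache = {}

def _specialization_key(t: torch.Tensor):
    # Simplified stand-in for the real SpecializationKey: capture the
    # properties we specialize on, not the concrete sizes.
    return (t.dtype, t.device.type, t.dim(),
            tuple(s == 1 for s in t.shape),   # broadcasting pattern
            t.is_contiguous(), t.requires_grad)

def _compile(fn, key):
    # Stand-in for "lower the traced graph to NNC and compile"; here we
    # just return the Python function itself.
    return fn

def pointwise_call(fn, *tensors):
    key = (fn,) + tuple(_specialization_key(t) for t in tensors)
    kernel = _kernel_cache.get(key)
    if kernel is None:                 # slow path: compile once per specialization
        kernel = _compile(fn, key)
        _kernel_cache[key] = kernel
    return kernel(*tensors)            # fast path: run the cached kernel

def add(a, b):
    return a + b

# Different concrete sizes with the same specialization reuse one cache entry.
pointwise_call(add, torch.rand(4, 4), torch.rand(4, 4))
pointwise_call(add, torch.rand(16, 16), torch.rand(16, 16))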

Picking the right data to go into the SpecializationKey is crucial. If we don't specialize enough, we will be forced to add checks to the fast path. If we specialize too much, recompilation could become an issue. Here is an example of what the prototype SpecializationKey looks like for add(rand(1, 8), rand(8, 8).transpose(0, 1)):

[SpecializationKey(
    alias_group=0,
    ndim=2,
    dtype=torch.float32,
    device=device(type='cpu'),
    layout=torch.strided,
    requires_grad=False,
    out=False,
    shape=['one', 'other'],
    stride=['contiguous', 'one']),
 SpecializationKey(
    alias_group=0,
    ndim=2,
    dtype=torch.float32,
    device=device(type='cpu'),
    layout=torch.strided,
    requires_grad=False,
    out=False,
    shape=['other', 'other'],
    stride=['one', 'transposed_contiguous'])]

There is one key for each input; the fields mirror the properties shown above (alias_group, ndim, dtype, device, layout, requires_grad, out, and per-dimension shape and stride categories).

For speed, this is implemented with packed bit-vectors. The key may need some tweaking, but we think it strikes a good balance as a starting point for discussion.
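As an aside, here is one possible reading of the per-dimension categories printed above, written out as plain Python. This is an interpretation, not the prototype's bit-vector code, and the 'other' fallback for strides is a guess:

import torch

def shape_categories(t: torch.Tensor):
    # Abstract each dimension to 'one' (broadcastable, size 1) or 'other',
    # rather than recording the concrete size.
    return ['one' if s == 1 else 'other' for s in t.shape]

def stride_categories(t: torch.Tensor):
    # Interpretation of the stride labels above:
    #   'one'                   stride == 1
    #   'contiguous'            stride[i] == stride[i+1] * size[i+1]
    #   'transposed_contiguous' stride[i] == stride[i-1] * size[i-1]
    #   'other'                 anything else (assumed fallback)
    cats = []
    for i in range(t.dim()):
        if t.stride(i) == 1:
            cats.append('one')
        elif i + 1 < t.dim() and t.stride(i) == t.stride(i + 1) * t.size(i + 1):
            cats.append('contiguous')
        elif i > 0 and t.stride(i) == t.stride(i - 1) * t.size(i - 1):
            cats.append('transposed_contiguous')
        else:
            cats.append('other')
    return cats

a = torch.rand(1, 8)
b = torch.rand(8, 8).transpose(0, 1)
print(shape_categories(a), stride_categories(a))  # ['one', 'other'] ['contiguous', 'one']
print(shape_categories(b), stride_categories(b))  # ['other', 'other'] ['one', 'transposed_contiguous']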

Performance Results

The chart below shows speedups comparing the performance of our prototype add() to the existing torch.add() on a wide variety of input types. We show both CPU (1-thread, Coffee Lake) and GPU (GTX 1070) results for sizes 1x1, 512x512, and 8192x8192. The 1x1 size is meant to measure overheads, while the larger sizes measure generated code quality. I ran each version hundreds of times (thousands for the smaller sizes) and report the median speedup. You can find the definition of each experiment here.
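The harness itself isn't reproduced here, but the measurement is essentially median-of-many-runs; a minimal sketch using torch.utils.benchmark (assuming the prototype add defined earlier is in scope) would look like:

import torch
import torch.utils.benchmark as benchmark

def median_speedup(size, device='cpu', runs=1000):
    a = torch.rand(size, device=device)
    b = torch.rand(size, device=device)
    # Baseline: the existing TensorIterator-backed torch.add.
    eager = benchmark.Timer('torch.add(a, b)',
                            globals={'torch': torch, 'a': a, 'b': b}).timeit(runs).median
    # Prototype: the @pointwise_operator-compiled add (assumed in scope).
    proto = benchmark.Timer('add(a, b)',
                            globals={'add': add, 'a': a, 'b': b}).timeit(runs).median
    return eager / proto

for size in [(1, 1), (512, 512), (8192, 8192)]:
    print(size, median_speedup(size))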

With only a few exceptions, our prototype is faster than or the same as torch.add. In most cases the speedup is a few percent, but there are some cases where the speedup is up to 3.5x. For CPU, we see huge speedups on the type promotion test cases. For GPU, we see big speedups (up to 1.5x) on the broadcasting test cases.

[Chart: speedups of the prototype add() over torch.add() for CPU and GPU across the test cases]

Pointwise Fusion versus TorchScript

The last 3 bars in the chart show a small example of an explicit pointwise fusion. For the non-out-variant cases, the existing TorchScript fuser can also fuse this example:

# This prototype:
@torch.jit.te.pointwise_operator
def fused_addnorm(a, b, m, d):
    return (a + b - m) / d

# Same algorithm with the existing TorchScript-based fuser:
torch._C._jit_override_can_fuse_on_cpu(True)  
@torch.jit.script
def fused_addnorm_ts(a, b, m, d):
    return (a + b - m) / d
    

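As a quick, hedged usage example (assuming both definitions above are available), the two versions can be checked against plain eager execution before timing them:

import torch

a, b, m, d = (torch.rand(512, 512) for _ in range(4))
d = d + 1e-3                      # keep the divisor away from zero

expected = (a + b - m) / d        # plain, unfused eager computation
assert torch.allclose(fused_addnorm(a, b, m, d), expected)
assert torch.allclose(fused_addnorm_ts(a, b, m, d), expected)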
Here is a performance comparison showing speedups over eager (unfused) as a baseline and speedups over TorchScript fusion as a second baseline.

CPU speedups over either eager or TS
                       1x1    512x512 8192x8192
forward (vs unfused)   3.41x  1.99x   2.39x
forward (vs TS-fused)  3.16x  1.07x   1.00x
backward (vs unfused)  1.14x  1.57x   1.67x
backward (vs TS-fused) 1.70x  1.12x   1.00x

GPU speedups over either eager or TS
                       1x1    512x512 8192x8192
forward (vs unfused)   1.94x  1.36x   1.92x
forward (vs TS-fused)  1.48x  1.14x   1.00x
backward (vs unfused)  1.18x  1.18x   1.36x
backward (vs TS-fused) 1.33x  1.32x   1.00x

We can see that for 8192x8192, the two fusers perform the same (1.00x) for both CPU/GPU and forward/backward, but for smaller sizes this prototype is faster than TorchScript because it has much lower overheads. This prototype also supports dynamic shapes without needing to recompile, while the TorchScript version shape specializes.

Next Steps

This is still an early prototype and not yet production ready. There are still a ton of challenges to overcome, optimization opportunities to explore, and integration issues to work through. We are hoping this will start a discussion and help us get feedback on this new direction.