ExportDB

ExportDB is a centralized dataset of supported and unsupported export cases. It is targeted at users who want to understand specifically what kinds of code export supports, the subtleties of export, and how to modify their existing code to be compatible with export. Note that this is not an exhaustive list of everything that export supports, but it covers the most common and confusing use cases that users run into.

If you have a feature that you think needs a stronger guarantee from us to support in export, please create an issue in the pytorch/pytorch repo with a module:export tag.
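Every entry below follows the same pattern: define a torch.nn.Module, build example inputs, and call torch.export.export to produce the ExportedProgram shown under "Result". A minimal sketch of that workflow (the module and shapes here are illustrative, not taken from any specific entry):

```
import torch

class Toy(torch.nn.Module):
    # Illustrative module, not an ExportDB entry.
    def forward(self, x):
        return x.sin() + 1

example_args = (torch.randn(3, 2),)
ep = torch.export.export(Toy(), example_args)
print(ep)  # prints an ExportedProgram like the "Result" listings below
```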

Supported#

assume_constant_result#

Original source code:

# mypy: allow-untyped-defs

import torch
import torch._dynamo as torchdynamo


class AssumeConstantResult(torch.nn.Module):
    """
    Applying the assume_constant_result decorator burns non-traceable code into the graph as a constant.
    """

@torchdynamo.assume_constant_result
def get_item(self, y):
    return y.int().item()

def forward(self, x, y):
    return x[: self.get_item(y)]

example_args = (torch.randn(3, 2), torch.tensor(4))
tags = {"torch.escape-hatch"}
model = AssumeConstantResult()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]", y: "i64[]"):
            slice_1: "f32[3, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 4);  x = None
            return (slice_1,)

Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT

# outputs
slice_1: USER_OUTPUT

Range constraints: {}

autograd_function#

Note

Tags:

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs

import torch


class MyAutogradFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output + 1

class AutogradFunction(torch.nn.Module):
    """
    TorchDynamo does not keep track of backward() on autograd functions. We recommend using allow_in_graph to mitigate this problem.
    """

def forward(self, x):
    return MyAutogradFunction.apply(x)

example_args = (torch.randn(3, 2),)
model = AutogradFunction()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            clone: "f32[3, 2]" = torch.ops.aten.clone.default(x);  x = None
            return (clone,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
clone: USER_OUTPUT

Range constraints: {}
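The docstring above recommends allow_in_graph when the backward behavior of the autograd function matters. A minimal sketch of how that registration is typically applied, assuming a small wrapper function around the autograd Function (the wrapper and module names here are illustrative, and whether this is needed depends on your compile/export setup):

```
import torch
import torch._dynamo as torchdynamo

class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output + 1

def apply_my_fn(x):
    # Registering this wrapper with allow_in_graph asks Dynamo to keep the call
    # as a single node instead of tracing into its Python internals.
    return MyFn.apply(x)

torchdynamo.allow_in_graph(apply_my_fn)

class WrappedAutogradFunction(torch.nn.Module):
    def forward(self, x):
        return apply_my_fn(x)
```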

class_method#

Note

Tags:

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs

import torch


class ClassMethod(torch.nn.Module):
    """
    Class methods are inlined during tracing.
    """

@classmethod
def method(cls, x):
    return x + 1

def __init__(self) -> None:
    super().__init__()
    self.linear = torch.nn.Linear(4, 2)

def forward(self, x):
    x = self.linear(x)
    return self.method(x) * self.__class__.method(x) * type(self).method(x)

example_args = (torch.randn(3, 4),)
model = ClassMethod()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_linear_weight: "f32[2, 4]", p_linear_bias: "f32[2]", x: "f32[3, 4]"):
            linear: "f32[3, 2]" = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias);  x = p_linear_weight = p_linear_bias = None

             add: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1)
        add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1)

             mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add, add_1);  add = add_1 = None

             add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1);  linear = None

             mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(mul, add_2);  mul = add_2 = None
        return (mul_1,)

Graph signature:
# inputs
p_linear_weight: PARAMETER target='linear.weight'
p_linear_bias: PARAMETER target='linear.bias'
x: USER_INPUT

# outputs
mul_1: USER_OUTPUT

Range constraints: {}

cond_branch_class_method#

Original source code:

# mypy: allow-untyped-defs

import torch

from functorch.experimental.control_flow import cond


class MySubModule(torch.nn.Module):
    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)

class CondBranchClassMethod(torch.nn.Module):
    """
    The branch functions (true_fn and false_fn) passed to cond() must follow these rules:
    - both branches must take the same args, which must also match the branch args passed to cond
    - both branches must return a single tensor
    - the returned tensor must have the same tensor metadata, e.g. shape and dtype
    - branch functions can be free functions, nested functions, lambdas, or class methods
    - branch functions can not have closure variables
    - no inplace mutations on inputs or global variables

    This example demonstrates using a class method in cond().

    NOTE: If the `pred` is a test on a dim with batch size < 2, it will be specialized.
    """

def __init__(self) -> None:
    super().__init__()
    self.subm = MySubModule()

def bar(self, x):
    return x.sin()

def forward(self, x):
    return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])

example_args = (torch.randn(3),)
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
model = CondBranchClassMethod()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3]"):
            sin: "f32[3]" = torch.ops.aten.sin.default(x);  x = None
            return (sin,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
sin: USER_OUTPUT

Range constraints: {}

cond_branch_nested_function#

Original source code:

# mypy: allow-untyped-defs

import torch

from functorch.experimental.control_flow import cond


class CondBranchNestedFunction(torch.nn.Module):
    """
    The branch functions (true_fn and false_fn) passed to cond() must follow these rules:
    - both branches must take the same args, which must also match the branch args passed to cond
    - both branches must return a single tensor
    - the returned tensor must have the same tensor metadata, e.g. shape and dtype
    - branch functions can be free functions, nested functions, lambdas, or class methods
    - branch functions can not have closure variables
    - no inplace mutations on inputs or global variables

    This example demonstrates using a nested function in cond().

    NOTE: If the `pred` is a test on a dim with batch size < 2, it will be specialized.
    """

def forward(self, x):
    def true_fn(x):
        def inner_true_fn(y):
            return x + y

        return inner_true_fn(x)

    def false_fn(x):
        def inner_false_fn(y):
            return x - y

        return inner_false_fn(x)

    return cond(x.shape[0] < 10, true_fn, false_fn, [x])

example_args = (torch.randn(3),)
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
model = CondBranchNestedFunction()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3]"):
            add: "f32[3]" = torch.ops.aten.add.Tensor(x, x);  x = None
            return (add,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
add: USER_OUTPUT

Range constraints: {}

cond_branch_nonlocal_variables#

Original source code:

# mypy: allow-untyped-defs

import torch

from functorch.experimental.control_flow import cond


class CondBranchNonlocalVariables(torch.nn.Module):
    """
    The branch functions (true_fn and false_fn) passed to cond() must follow these rules:
    - both branches must take the same args, which must also match the branch args passed to cond
    - both branches must return a single tensor
    - the returned tensor must have the same tensor metadata, e.g. shape and dtype
    - branch functions can be free functions, nested functions, lambdas, or class methods
    - branch functions can not have closure variables
    - no inplace mutations on inputs or global variables

This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.

The code below will not work because capturing closure variables is not supported.
```
my_tensor_var = x + 100
my_primitive_var = 3.14

def true_fn(y):
    nonlocal my_tensor_var, my_primitive_var
    return y + my_tensor_var + my_primitive_var

def false_fn(y):
    nonlocal my_tensor_var, my_primitive_var
    return y - my_tensor_var - my_primitive_var

return cond(x.shape[0] > 5, true_fn, false_fn, [x])
```

NOTE: If the `pred` is a test on a dim with batch size < 2, it will be specialized.
"""

def forward(self, x):
    my_tensor_var = x + 100
    my_primitive_var = 3.14

    def true_fn(x, y, z):
        return x + y + z

    def false_fn(x, y, z):
        return x - y - z

    return cond(
        x.shape[0] > 5,
        true_fn,
        false_fn,
        [x, my_tensor_var, torch.tensor(my_primitive_var)],
    )

example_args = (torch.randn(6),)
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
model = CondBranchNonlocalVariables()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, c_lifted_tensor_0: "f32[]", x: "f32[6]"):
            add: "f32[6]" = torch.ops.aten.add.Tensor(x, 100)

             lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
        detach_: "f32[]" = torch.ops.aten.detach_.default(lift_fresh_copy);  lift_fresh_copy = None

             add_1: "f32[6]" = torch.ops.aten.add.Tensor(x, add);  x = add = None
        add_2: "f32[6]" = torch.ops.aten.add.Tensor(add_1, detach_);  add_1 = detach_ = None
        return (add_2,)

Graph signature:
# inputs
c_lifted_tensor_0: CONSTANT_TENSOR target='lifted_tensor_0'
x: USER_INPUT

# outputs
add_2: USER_OUTPUT

Range constraints: {}

cond_closed_over_variable#

Original source code:

# mypy: allow-untyped-defs

import torch

from functorch.experimental.control_flow import cond


class CondClosedOverVariable(torch.nn.Module):
    """
    torch.cond() supports branches closed over arbitrary variables.
    """

def forward(self, pred, x):
    def true_fn(val):
        return x * 2

    def false_fn(val):
        return x - 2

    return cond(pred, true_fn, false_fn, [x + 1])

example_args = (torch.tensor(True), torch.randn(3, 2))
tags = {"torch.cond", "python.closure"}
model = CondClosedOverVariable()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, pred: "b8[]", x: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1);  add = None

             true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        cond = torch.ops.higher_order.cond(pred, true_graph_0, false_graph_0, (x,));  pred = true_graph_0 = false_graph_0 = x = None
        getitem: "f32[3, 2]" = cond[0];  cond = None
        return (getitem,)

    class true_graph_0(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
                     mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, 2);  x = None
            return (mul,)

    class false_graph_0(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
                     sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(x, 2);  x = None
            return (sub,)

Graph signature:
# inputs
pred: USER_INPUT
x: USER_INPUT

# outputs
getitem: USER_OUTPUT

Range constraints: {}

cond_operands#

Original source code:

# mypy: allow-untyped-defs

import torch

from torch.export import Dim

x = torch.randn(3, 2)
y = torch.randn(2)
dim0_x = Dim("dim0_x")


class CondOperands(torch.nn.Module):
    """
    The operands passed to cond() must:
    - be a list of tensors
    - match the arguments of true_fn and false_fn

    NOTE: If the `pred` is a test on a dim with batch size < 2, it will be specialized.
    """

def forward(self, x, y):
    def true_fn(x, y):
        return x + y

    def false_fn(x, y):
        return x - y

    return torch.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])

example_args = (x, y)
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
extra_inputs = (torch.randn(2, 2), torch.randn(2))
dynamic_shapes = {"x": {0: dim0_x}, "y": None}
model = CondOperands()

torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s77, 2]", y: "f32[2]"):
            sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)

             gt: "Sym(s77 > 2)" = sym_size_int_1 > 2;  sym_size_int_1 = None

             true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, y));  gt = true_graph_0 = false_graph_0 = x = y = None
        getitem: "f32[s77, 2]" = cond[0];  cond = None
        return (getitem,)

    class true_graph_0(torch.nn.Module):
        def forward(self, x: "f32[s77, 2]", y: "f32[2]"):
                     add: "f32[s77, 2]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
            return (add,)

    class false_graph_0(torch.nn.Module):
        def forward(self, x: "f32[s77, 2]", y: "f32[2]"):
                     sub: "f32[s77, 2]" = torch.ops.aten.sub.Tensor(x, y);  x = y = None
            return (sub,)

Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT

# outputs
getitem: USER_OUTPUT

Range constraints: {s77: VR[0, int_oo]}

cond_predicate#

Original source code:

# mypy: allow-untyped-defs

import torch

from functorch.experimental.control_flow import cond


class CondPredicate(torch.nn.Module):
    """
    The conditional statement (aka predicate) passed to cond() must be one of the following:
    - a torch.Tensor with a single element
    - a boolean expression

    NOTE: If the `pred` is a test on a dim with batch size < 2, it will be specialized.
    """

def forward(self, x):
    pred = x.dim() > 2 and x.shape[2] > 10

    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

example_args = (torch.randn(6, 4, 3),)
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
model = CondPredicate()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[6, 4, 3]"):
            sin: "f32[6, 4, 3]" = torch.ops.aten.sin.default(x);  x = None
            return (sin,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
sin: USER_OUTPUT

Range constraints: {}

constrain_as_size_example#

Original source code:

# mypy: allow-untyped-defs

import torch


class ConstrainAsSizeExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide a hint so that we can trace further. Please look at the torch._check and torch._check_is_size APIs. torch._check_is_size is used for values that NEED to be used for constructing tensors.
    """

def forward(self, x):
    a = x.item()
    torch._check_is_size(a)
    torch._check(a <= 5)
    return torch.zeros((a, 5))

example_args = (torch.tensor(4),)
tags = {
    "torch.dynamic-value",
    "torch.escape-hatch",
}
model = ConstrainAsSizeExample()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]"):
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

         #
        sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item);  sym_constrain_range_for_size_default = None

             ge_1: "Sym(u0 >= 0)" = item >= 0
        _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
        le_1: "Sym(u0 <= 5)" = item <= 5
        _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None

             zeros: "f32[u0, 5]" = torch.ops.aten.zeros.default([item, 5], device = device(type='cpu'), pin_memory = False);  item = None
        return (zeros,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
zeros: USER_OUTPUT

Range constraints: {u0: VR[0, 5], u1: VR[0, 5]}

constrain_as_value_example#

Original source code:

# mypy: allow-untyped-defs

import torch


class ConstrainAsValueExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide a hint so that we can trace further. Please look at the torch._check and torch._check_is_size APIs. torch._check is used for values that don't need to be used for constructing tensors.
    """

def forward(self, x, y):
    a = x.item()
    torch._check(a >= 0)
    torch._check(a <= 5)

    if a < 6:
        return y.sin()
    return y.cos()

example_args = (torch.tensor(4), torch.randn(5, 5))
tags = {
    "torch.dynamic-value",
    "torch.escape-hatch",
}
model = ConstrainAsValueExample()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[5, 5]"):
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge_1: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 5)" = item <= 5;  item = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None

             sin: "f32[5, 5]" = torch.ops.aten.sin.default(y);  y = None
        return (sin,)

Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT

# outputs
sin: USER_OUTPUT

Range constraints: {u0: VR[0, 5], u1: VR[0, 5]}

decorator#

Note

Tags:

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs

import functools

import torch


def test_decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs) + 1

    return wrapper

class Decorator(torch.nn.Module):
    """
    Decorator calls are inlined into the exported function during tracing.
    """

@test_decorator
def forward(self, x, y):
    return x + y

example_args = (torch.randn(3, 2), torch.randn(3, 2))
model = Decorator()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]", y: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, y);  x = y = None

             add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1);  add = None
        return (add_1,)

Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT

# outputs
add_1: USER_OUTPUT

Range constraints: {}

dictionary#

Original source code:

# mypy: allow-untyped-defs

import torch


class Dictionary(torch.nn.Module):
    """
    Dictionary structures are inlined and flattened during tracing.
    """

def forward(self, x, y):
    elements = {}
    elements["x2"] = x * x
    y = y * elements["x2"]
    return {"y": y}

example_args = (torch.randn(3, 2), torch.tensor(4))
tags = {"python.data-structure"}
model = Dictionary()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]", y: "i64[]"):
            mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, x);  x = None

             mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(y, mul);  y = mul = None
        return (mul_1,)

Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT

# outputs
mul_1: USER_OUTPUT

Range constraints: {}

dynamic_shape_assert#

Original source code:

# mypy: allow-untyped-defs

import torch


class DynamicShapeAssert(torch.nn.Module):
    """
    A basic usage of Python assertions.
    """

def forward(self, x):
    # assertion with error message
    assert x.shape[0] > 2, f"{x.shape[0]} is greater than 2"
    # assertion without error message
    assert x.shape[0] > 1
    return x

example_args = (torch.randn(3, 2),)
tags = {"python.assert"}
model = DynamicShapeAssert()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            return (x,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
x: USER_OUTPUT

Range constraints: {}

dynamic_shape_constructor#

Original source code:

# mypy: allow-untyped-defs

import torch


class DynamicShapeConstructor(torch.nn.Module):
    """
    Tensor constructors should be captured with dynamic shape inputs rather than being baked in with static shape.
    """

def forward(self, x):
    return torch.zeros(x.shape[0] * 2)

example_args = (torch.randn(3, 2),)
tags = {"torch.dynamic-shape"}
model = DynamicShapeConstructor()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            zeros: "f32[6]" = torch.ops.aten.zeros.default([6], device = device(type='cpu'), pin_memory = False)
            return (zeros,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
zeros: USER_OUTPUT

Range constraints: {}

dynamic_shape_if_guard#

Original source code:

# mypy: allow-untyped-defs

import torch


class DynamicShapeIfGuard(torch.nn.Module):
    """
    An if statement with a backed dynamic-shape predicate will be specialized into one particular branch and generate a guard. However, export will fail if the dimension is marked as dynamic from the higher-level API.
    """

def forward(self, x):
    if x.shape[0] == 3:
        return x.cos()

    return x.sin()

example_args = (torch.randn(3, 2, 2),)
tags = {"torch.dynamic-shape", "python.control-flow"}
model = DynamicShapeIfGuard()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2, 2]"):
            cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(x);  x = None
            return (cos,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
cos: USER_OUTPUT

Range constraints: {}

dynamic_shape_map#

Original source code:

# mypy: allow-untyped-defs

import torch

from functorch.experimental.control_flow import map


class DynamicShapeMap(torch.nn.Module):
    """
    functorch map() maps a function over the first tensor dimension.
    """

def forward(self, xs, y):
    def body(x, y):
        return x + y

    return map(body, xs, y)

example_args = (torch.randn(3, 2), torch.randn(2))
tags = {"torch.dynamic-shape", "torch.map"}
model = DynamicShapeMap()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, xs: "f32[3, 2]", y: "f32[2]"):
            body_graph_0 = self.body_graph_0
            map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y]);  body_graph_0 = xs = y = None
            getitem: "f32[3, 2]" = map_impl[0];  map_impl = None
            return (getitem,)

    class body_graph_0(torch.nn.Module):
        def forward(self, xs: "f32[2]", y: "f32[2]"):
                     add: "f32[2]" = torch.ops.aten.add.Tensor(xs, y);  xs = y = None
            return (add,)

Graph signature:
# inputs
xs: USER_INPUT
y: USER_INPUT

# outputs
getitem: USER_OUTPUT

Range constraints: {}

dynamic_shape_slicing#

Original source code:

# mypy: allow-untyped-defs

import torch


class DynamicShapeSlicing(torch.nn.Module):
    """
    Slices with dynamic shape arguments should be captured into the graph rather than being baked in.
    """

def forward(self, x):
    return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]

example_args = (torch.randn(3, 2),)
tags = {"torch.dynamic-shape"}
model = DynamicShapeSlicing()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            slice_1: "f32[1, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 1);  x = None
            slice_2: "f32[1, 1]" = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 9223372036854775807, 2);  slice_1 = None
            return (slice_2,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
slice_2: USER_OUTPUT

Range constraints: {}

dynamic_shape_view#

Original source code:

# mypy: allow-untyped-defs

import torch


class DynamicShapeView(torch.nn.Module):
    """
    Dynamic shapes should be propagated to view arguments instead of being baked into the exported graph.
    """

def forward(self, x):
    new_x_shape = x.size()[:-1] + (2, 5)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1)

example_args = (torch.randn(10, 10),)
tags = {"torch.dynamic-shape"}
model = DynamicShapeView()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]"):
            view: "f32[10, 2, 5]" = torch.ops.aten.view.default(x, [10, 2, 5]);  x = None

             permute: "f32[10, 5, 2]" = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
        return (permute,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
permute: USER_OUTPUT

Range constraints: {}

fn_with_kwargs#

Original source code:

# mypy: allow-untyped-defs

import torch


class FnWithKwargs(torch.nn.Module):
    """
    Keyword arguments are not supported at the moment.
    """

def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs):
    out = pos0
    for arg in tuple0:
        out = out * arg
    for arg in myargs:
        out = out * arg
    out = out * mykw0
    out = out * mykwargs["input0"] * mykwargs["input1"]
    return out

example_args = (
    torch.randn(4),
    (torch.randn(4), torch.randn(4)),
    *[torch.randn(4), torch.randn(4)],
)
example_kwargs = {
    "mykw0": torch.randn(4),
    "input0": torch.randn(4),
    "input1": torch.randn(4),
}
tags = {"python.data-structure"}
model = FnWithKwargs()

torch.export.export(model, example_args, example_kwargs)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, pos0: "f32[4]", tuple0_0: "f32[4]", tuple0_1: "f32[4]", myargs_0: "f32[4]", myargs_1: "f32[4]", mykw0: "f32[4]", input0: "f32[4]", input1: "f32[4]"):
            mul: "f32[4]" = torch.ops.aten.mul.Tensor(pos0, tuple0_0);  pos0 = tuple0_0 = None
            mul_1: "f32[4]" = torch.ops.aten.mul.Tensor(mul, tuple0_1);  mul = tuple0_1 = None

             mul_2: "f32[4]" = torch.ops.aten.mul.Tensor(mul_1, myargs_0);  mul_1 = myargs_0 = None
        mul_3: "f32[4]" = torch.ops.aten.mul.Tensor(mul_2, myargs_1);  mul_2 = myargs_1 = None

             mul_4: "f32[4]" = torch.ops.aten.mul.Tensor(mul_3, mykw0);  mul_3 = mykw0 = None

             mul_5: "f32[4]" = torch.ops.aten.mul.Tensor(mul_4, input0);  mul_4 = input0 = None
        mul_6: "f32[4]" = torch.ops.aten.mul.Tensor(mul_5, input1);  mul_5 = input1 = None
        return (mul_6,)

Graph signature:
# inputs
pos0: USER_INPUT
tuple0_0: USER_INPUT
tuple0_1: USER_INPUT
myargs_0: USER_INPUT
myargs_1: USER_INPUT
mykw0: USER_INPUT
input0: USER_INPUT
input1: USER_INPUT

# outputs
mul_6: USER_OUTPUT

Range constraints: {}

list_contains#

Original source code:

# mypy: allow-untyped-defs

import torch


class ListContains(torch.nn.Module):
    """
    List containment relations can be checked on a dynamic shape or on constants.
    """

def forward(self, x):
    assert x.size(-1) in [6, 2]
    assert x.size(0) not in [4, 5, 6]
    assert "monkey" not in ["cow", "pig"]
    return x + x

example_args = (torch.randn(3, 2),)
tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"}
model = ListContains()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, x);  x = None
            return (add,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
add: USER_OUTPUT

Range constraints: {}

list_unpack#

Original source code:

# mypy: allow-untyped-defs

import torch


class ListUnpack(torch.nn.Module):
    """
    Lists are treated as static constructs, therefore unpacking should be erased after tracing.
    """

def forward(self, args: list[torch.Tensor]):
    """
    Lists are treated as static construct, therefore unpacking should be
    erased after tracing.
    """
    x, *y = args
    return x + y[0]

example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],)
tags = {"python.control-flow", "python.data-structure"}
model = ListUnpack()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, args_0: "f32[3, 2]", args_1: "i64[]", args_2: "i64[]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(args_0, args_1);  args_0 = args_1 = None
            return (add,)

Graph signature:
# inputs
args_0: USER_INPUT
args_1: USER_INPUT
args_2: USER_INPUT

# outputs
add: USER_OUTPUT

Range constraints: {}

nested_function#

Original source code:

# mypy: allow-untyped-defs

import torch


class NestedFunction(torch.nn.Module):
    """
    Nested functions are traced through. Side effects on global captures are not supported though.
    """

def forward(self, a, b):
    x = a + b
    z = a - b

    def closure(y):
        nonlocal x
        x += 1
        return x * y + z

    return closure(x)

example_args = (torch.randn(3, 2), torch.randn(2))
tags = {"python.closure"}
model = NestedFunction()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, a: "f32[3, 2]", b: "f32[2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(a, b)

             sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(a, b);  a = b = None

             add_: "f32[3, 2]" = torch.ops.aten.add_.Tensor(add, 1);  add = None

             mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add_, add_);  add_ = None
        add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, sub);  mul = sub = None
        return (add_1,)

Graph signature:
# inputs
a: USER_INPUT
b: USER_INPUT

# outputs
add_1: USER_OUTPUT

Range constraints: {}

null_context_manager#

Original source code:

# mypy: allow-untyped-defs

import contextlib

import torch


class NullContextManager(torch.nn.Module):
    """
    Null context manager in Python will be traced out.
    """

def forward(self, x):
    """
    Null context manager in Python will be traced out.
    """
    ctx = contextlib.nullcontext()
    with ctx:
        return x.sin() + x.cos()

example_args = (torch.randn(3, 2),)
tags = {"python.context-manager"}
model = NullContextManager()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            sin: "f32[3, 2]" = torch.ops.aten.sin.default(x)
            cos: "f32[3, 2]" = torch.ops.aten.cos.default(x);  x = None
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(sin, cos);  sin = cos = None
            return (add,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
add: USER_OUTPUT

Range constraints: {}

pytree_flatten#

Note

Tags:

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs

import torch

from torch.utils import _pytree as pytree


class PytreeFlatten(torch.nn.Module):
    """
    Pytree from PyTorch can be captured by TorchDynamo.
    """

def forward(self, x):
    y, _spec = pytree.tree_flatten(x)
    return y[0] + 1

example_args = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},)
model = PytreeFlatten()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x_0_1: "f32[3, 2]", x_0_2: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x_0_1, 1);  x_0_1 = None
            return (add,)

Graph signature:
# inputs
x_0_1: USER_INPUT
x_0_2: USER_INPUT

# outputs
add: USER_OUTPUT

Range constraints: {}

scalar_output#

Original source code:

# mypy: allow-untyped-defs

import torch

from torch.export import Dim

x = torch.randn(3, 2)
dim1_x = Dim("dim1_x")


class ScalarOutput(torch.nn.Module):
    """
    Returning scalar values from the graph is supported, in addition to Tensor outputs. Symbolic shapes are captured and rank is specialized.
    """

    def __init__(self) -> None:
        super().__init__()

def forward(self, x):
    return x.shape[1] + 1

example_args = (x,)
tags = {"torch.dynamic-shape"}
dynamic_shapes = {"x": {1: dim1_x}}
model = ScalarOutput()

torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, s27]"):
            sym_size_int_1: "Sym(s27)" = torch.ops.aten.sym_size.int(x, 1);  x = None

             add: "Sym(s27 + 1)" = sym_size_int_1 + 1;  sym_size_int_1 = None
        return (add,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
add: USER_OUTPUT

Range constraints: {s27: VR[0, int_oo]}

specialized_attribute#

Note

Tags:

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs

from enum import Enum

import torch


class Animal(Enum):
    COW = "moo"


class SpecializedAttribute(torch.nn.Module):
    """
    Model attributes are specialized.
    """

def __init__(self) -> None:
    super().__init__()
    self.a = "moo"
    self.b = 4

def forward(self, x):
    if self.a == Animal.COW.value:
        return x * x + self.b
    else:
        raise ValueError("bad")

example_args = (torch.randn(3, 2),)
model = SpecializedAttribute()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, x);  x = None
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, 4);  mul = None
            return (add,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
add: USER_OUTPUT

Range constraints: {}

static_for_loop#

Original source code:

# mypy: allow-untyped-defs

import torch


class StaticForLoop(torch.nn.Module):
    """
    A for loop with a constant number of iterations should be unrolled in the exported graph.
    """

def forward(self, x):
    # constant
    ret = [i + x for i in range(10)]
    return ret

example_args = (torch.randn(3, 2),)
tags = {"python.control-flow"}
model = StaticForLoop()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 0)
            add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1)
            add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 2)
            add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 3)
            add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4)
            add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 5)
            add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 6)
            add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 7)
            add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 8)
            add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 9);  x = None
            return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)

Graph signature:
# inputs
x: USER_INPUT

# outputs
add: USER_OUTPUT
add_1: USER_OUTPUT
add_2: USER_OUTPUT
add_3: USER_OUTPUT
add_4: USER_OUTPUT
add_5: USER_OUTPUT
add_6: USER_OUTPUT
add_7: USER_OUTPUT
add_8: USER_OUTPUT
add_9: USER_OUTPUT

Range constraints: {}

static_if#

Original source code:

# mypy: allow-untyped-defs

import torch


class StaticIf(torch.nn.Module):
    """
    An if statement with a static predicate value should be traced through with the taken branch.
    """

def forward(self, x):
    if len(x.shape) == 3:
        return x + torch.ones(1, 1, 1)

    return x

example_args = (torch.randn(3, 2, 2),)
tags = {"python.control-flow"}
model = StaticIf()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2, 2]"):
            ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device = device(type='cpu'), pin_memory = False)
            add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(x, ones);  x = ones = None
            return (add,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
add: USER_OUTPUT

Range constraints: {}

tensor_setattr#

Original source code:

# mypy: allow-untyped-defs

import torch


class TensorSetattr(torch.nn.Module):
    """
    setattr() call onto tensors is not supported.
    """

    def forward(self, x, attr):
        setattr(x, attr, torch.randn(3, 2))
        return x + 4

example_args = (torch.randn(3, 2), "attr") tags = {"python.builtin"} model = TensorSetattr()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]", attr):
            randn: "f32[3, 2]" = torch.ops.aten.randn.default([3, 2], device = device(type='cpu'), pin_memory = False);  randn = None

             add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4);  x = None
        return (add,)

Graph signature:
# inputs
x: USER_INPUT
attr: USER_INPUT

# outputs
add: USER_OUTPUT

Range constraints: {}

type_reflection_method#

Original source code:

# mypy: allow-untyped-defs

import torch


class A:
    @classmethod
    def func(cls, x):
        return 1 + x


class TypeReflectionMethod(torch.nn.Module):
    """
    type() calls on custom objects followed by attribute accesses are not allowed due to their overly dynamic nature.
    """

def forward(self, x):
    a = A()
    return type(a).func(x)

example_args = (torch.randn(3, 4),)
tags = {"python.builtin"}
model = TypeReflectionMethod()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 4]"):
            add: "f32[3, 4]" = torch.ops.aten.add.Tensor(x, 1);  x = None
            return (add,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
add: USER_OUTPUT

Range constraints: {}

user_input_mutation#

Original source code:

# mypy: allow-untyped-defs

import torch


class UserInputMutation(torch.nn.Module):
    """
    Directly mutate user input in forward.
    """

def forward(self, x):
    x.mul_(2)
    return x.cos()

example_args = (torch.randn(3, 2),)
tags = {"torch.mutation"}
model = UserInputMutation()

torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            mul_: "f32[3, 2]" = torch.ops.aten.mul_.Tensor(x, 2);  x = None

             cos: "f32[3, 2]" = torch.ops.aten.cos.default(mul_);  mul_ = None
        return (cos,)

Graph signature:
# inputs
x: USER_INPUT

# outputs
cos: USER_OUTPUT

Range constraints: {}

Not Supported Yet#

dynamic_shape_round#

Original source code:

# mypy: allow-untyped-defs

import torch

from torch._export.db.case import SupportLevel
from torch.export import Dim


class DynamicShapeRound(torch.nn.Module):
    """
    Calling round() on dynamic shapes is not supported.
    """

def forward(self, x):
    return x[: round(x.shape[0] / 2)]

x = torch.randn(3, 2)
dim0_x = Dim("dim0_x")
example_args = (x,)
tags = {"torch.dynamic-shape", "python.builtin"}
support_level = SupportLevel.NOT_SUPPORTED_YET
dynamic_shapes = {"x": {0: dim0_x}}
model = DynamicShapeRound()

torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result:

Unsupported: Constraints violated (dim0_x)! For more information, run with TORCH_LOGS="+dynamic".
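A common way to sidestep this limitation is to replace round() with integer arithmetic on the symbolic size, e.g. floor division. This rewrite is illustrative and not part of the original example, and note that floor division rounds differently from round() for odd sizes:

```
import torch
from torch.export import Dim

class DynamicShapeFloorDiv(torch.nn.Module):  # hypothetical rewrite of DynamicShapeRound
    def forward(self, x):
        # x.shape[0] // 2 stays an integer expression over the symbolic size,
        # unlike round(x.shape[0] / 2), which requires unsupported float rounding.
        return x[: x.shape[0] // 2]

example_args = (torch.randn(3, 2),)
dynamic_shapes = {"x": {0: Dim("dim0_x")}}
ep = torch.export.export(DynamicShapeFloorDiv(), example_args, dynamic_shapes=dynamic_shapes)
```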

model_attr_mutation#

Original source code:

# mypy: allow-untyped-defs

import torch

from torch._export.db.case import SupportLevel


class ModelAttrMutation(torch.nn.Module):
    """
    Attribute mutation is not supported.
    """

def __init__(self) -> None:
    super().__init__()
    self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]

def recreate_list(self):
    return [torch.zeros(3, 2), torch.zeros(3, 2)]

def forward(self, x):
    self.attr_list = self.recreate_list()
    return x.sum() + self.attr_list[0].sum()

example_args = (torch.randn(3, 2),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = ModelAttrMutation()

torch.export.export(model, example_args)

Result:

AssertionError: Mutating module attribute attr_list during export.
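One way to make this kind of code exportable (an illustrative rewrite, not part of ExportDB) is to avoid reassigning the module attribute inside forward and use a local value instead:

```
import torch

class ModelAttrNoMutation(torch.nn.Module):  # hypothetical rewrite of ModelAttrMutation
    def __init__(self) -> None:
        super().__init__()
        self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]

    def forward(self, x):
        # Build the fresh tensors as locals instead of overwriting self.attr_list.
        local_list = [torch.zeros(3, 2), torch.zeros(3, 2)]
        return x.sum() + local_list[0].sum()

ep = torch.export.export(ModelAttrNoMutation(), (torch.randn(3, 2),))
```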

optional_input#

Original source code:

# mypy: allow-untyped-defs

import torch

from torch._export.db.case import SupportLevel


class OptionalInput(torch.nn.Module):
    """
    Tracing through optional input is not supported yet.
    """

def forward(self, x, y=torch.randn(2, 3)):
    if y is not None:
        return x + y
    return x

example_args = (torch.randn(2, 3),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = OptionalInput()

torch.export.export(model, example_args)

Result:

Unsupported: Tracing through optional input is not supported yet
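A workaround (illustrative, not from ExportDB) is to pass the optional argument explicitly in the example inputs, so export never has to rely on the default value:

```
import torch

class ExplicitInput(torch.nn.Module):  # hypothetical rewrite of OptionalInput
    def forward(self, x, y):
        # y is now a required input instead of defaulting to a module-level tensor.
        return x + y

example_args = (torch.randn(2, 3), torch.randn(2, 3))
ep = torch.export.export(ExplicitInput(), example_args)
```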

unsupported_operator#

Original source code:

# mypy: allow-untyped-defs

import torch

from torch._export.db.case import SupportLevel


class TorchSymMin(torch.nn.Module):
    """
    torch.sym_min operator is not supported in export.
    """

def forward(self, x):
    return x.sum() + torch.sym_min(x.size(0), 100)

example_args = (torch.randn(3, 2),)
tags = {"torch.operator"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = TorchSymMin()

torch.export.export(model, example_args)

Result:

Unsupported: torch.* op returned non-Tensor