torch.export

Overview

torch.export.export() takes a torch.nn.Module and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized.

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: torch.export.ExportedProgram = export(
    Mod(), args=example_args
)
print(exported_program)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
            # code: a = torch.sin(x)
            sin: "f32[10, 10]" = torch.ops.aten.sin.default(x)

            # code: b = torch.cos(y)
            cos: "f32[10, 10]" = torch.ops.aten.cos.default(y)

            # code: return a + b
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos)
            return (add,)

Graph signature:
    ExportGraphSignature(
        input_specs=[
            InputSpec(
                kind=<InputKind.USER_INPUT: 1>,
                arg=TensorArgument(name='x'),
                target=None,
                persistent=None
            ),
            InputSpec(
                kind=<InputKind.USER_INPUT: 1>,
                arg=TensorArgument(name='y'),
                target=None,
                persistent=None
            )
        ],
        output_specs=[
            OutputSpec(
                kind=<OutputKind.USER_OUTPUT: 1>,
                arg=TensorArgument(name='add'),
                target=None
            )
        ]
    )
Range constraints: {}

torch.export produces a clean intermediate representation (IR) with the following invariants. More specifications about the IR can be found here.

Under the hood, torch.export leverages the following latest technologies:

Existing frameworks

torch.compile() also utilizes the same PT2 stack as torch.export, but is slightly different:

Compared to torch.fx.symbolic_trace(), torch.export traces using TorchDynamo which operates at the Python bytecode level, giving it the ability to trace arbitrary Python constructs not limited by what Python operator overloading supports. Additionally, torch.export keeps fine-grained track of tensor metadata, so that conditionals on things like tensor shapes do not fail tracing. In general, torch.export is expected to work on more user programs, and produce lower-level graphs (at the torch.ops.aten operator level). Note that users can still use torch.fx.symbolic_trace() as a preprocessing step before torch.export.

Compared to torch.jit.script(), torch.export does not capture Python control flow or data structures, but it supports more Python language features than TorchScript (as it is easier to have comprehensive coverage over Python bytecodes). The resulting graphs are simpler and only have straight line control flow (except for explicit control flow operators).

Compared to torch.jit.trace(), torch.export is sound: it is able to trace code that performs integer computation on sizes and records all of the side-conditions necessary to show that a particular trace is valid for other inputs.

Exporting a PyTorch Model

An Example

The main entrypoint is torch.export.export(), which takes a callable (torch.nn.Module, function, or method) and sample inputs, and captures the computation graph into a torch.export.ExportedProgram. An example:

import torch
from torch.export import export

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
            # code: a = self.conv(x)
            conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1])

            # code: a.add_(constant)
            add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant)

            # code: return self.maxpool(self.relu(a))
            relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_)
            max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3])
            return (max_pool2d,)

Graph signature:
    ExportGraphSignature(
        input_specs=[
            InputSpec(
                kind=<InputKind.PARAMETER: 2>,
                arg=TensorArgument(name='p_conv_weight'),
                target='conv.weight',
                persistent=None
            ),
            InputSpec(
                kind=<InputKind.PARAMETER: 2>,
                arg=TensorArgument(name='p_conv_bias'),
                target='conv.bias',
                persistent=None
            ),
            InputSpec(
                kind=<InputKind.USER_INPUT: 1>,
                arg=TensorArgument(name='x'),
                target=None,
                persistent=None
            ),
            InputSpec(
                kind=<InputKind.USER_INPUT: 1>,
                arg=TensorArgument(name='constant'),
                target=None,
                persistent=None
            )
        ],
        output_specs=[
            OutputSpec(
                kind=<OutputKind.USER_OUTPUT: 1>,
                arg=TensorArgument(name='max_pool2d'),
                target=None
            )
        ]
    )
Range constraints: {}

Inspecting the ExportedProgram, we can note the following:

Non-Strict Export

In PyTorch 2.3, we introduced a new mode of tracing called non-strict mode. It’s still going through hardening, so if you run into any issues, please file them to GitHub with the “oncall: export” tag.

In non-strict mode, we trace through the program using the Python interpreter. Your code will execute exactly as it would in eager mode; the only difference is that all Tensor objects will be replaced by ProxyTensors, which will record all their operations into a graph.

In strict mode (strict=True), we first trace through the program using TorchDynamo, a bytecode analysis engine. TorchDynamo does not actually execute your Python code. Instead, it symbolically analyzes it and builds a graph based on the results. This analysis allows torch.export to provide stronger guarantees about safety, but not all Python code is supported.

An example of a case where one might want to use non-strict mode is if you run into an unsupported TorchDynamo feature that might not be easily solved, and you know the Python code is not actually needed for the computation. For example:

import torch
from torch.export import export

class ContextManager:
    def __init__(self):
        self.count = 0

    def __enter__(self):
        self.count += 1

    def __exit__(self, exc_type, exc_value, traceback):
        self.count -= 1

class M(torch.nn.Module):
    def forward(self, x):
        with ContextManager():
            return x.sin() + x.cos()

export(M(), (torch.ones(3, 3),), strict=False)  # Non-strict traces successfully
export(M(), (torch.ones(3, 3),), strict=True)  # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager

In this example, the first call uses non-strict mode (through the strict=False flag) and traces successfully. The second call uses strict mode (strict=True) and fails because TorchDynamo is unable to support this context manager. One option is to rewrite the code (see Limitations of torch.export). However, the context manager does not affect the tensor computations in the model, so we can go with the non-strict mode’s result.

Export for Training and Inference

In PyTorch 2.5, we introduced a new API called export_for_training(). It’s still going through hardening, so if you run into any issues, please file them to GitHub with the “oncall: export” tag.

In this API, we produce the most generic IR that contains all ATen operators (including both functional and non-functional) which can be used to train in eager PyTorch Autograd. This API is intended for eager training use cases such as PT2 Quantization and will soon be the default IR of torch.export.export. To read further about the motivation behind this change, please refer to https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206

When this API is combined with run_decompositions(), you should be able to get inference IR with any desired decomposition behavior.

To show some examples:

class ConvBatchnorm(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

mod = ConvBatchnorm()
inp = torch.randn(1, 1, 3, 3)

ep_for_training = torch.export.export_for_training(mod, (inp,))
print(ep_for_training)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
            conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
            add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1)
            batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True)
            return (batch_norm,)

From the above output, you can see that export_for_training() produces pretty much the same ExportedProgram as export() except for the operators in the graph. You can see that we captured batch_norm in the most general form. This op is non-functional and will be lowered to different ops when running inference.

You can also go from this IR to an inference IR via run_decompositions() with arbitrary customizations.

# Lower to core aten inference IR, but keep conv2d
decomp_table = torch.export.default_decompositions()
del decomp_table[torch.ops.aten.conv2d.default]
ep_for_inference = ep_for_training.run_decompositions(decomp_table)

print(ep_for_inference)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
            conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
            add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
            _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
            getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
            getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
            getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]
            return (getitem_3, getitem_4, add, getitem)

Here you can see that we kept the conv2d op in the IR while decomposing the rest. The IR is now a functional IR containing core aten operators, except for conv2d.

You can do even more customization by directly registering your chosen decomposition behaviors.


# Lower to core aten inference IR, but customize conv2d
decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
ep_for_inference = ep_for_training.run_decompositions(decomp_table)

print(ep_for_inference)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
            convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
            mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2)
            add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
            _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
            getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
            getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
            getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]
            return (getitem_3, getitem_4, add, getitem)

Expressing Dynamism

By default, torch.export will trace the program assuming all input shapes are static, specializing the exported program to those dimensions. However, some dimensions, such as a batch dimension, can be dynamic and vary from run to run. You must specify such dimensions by creating them with the torch.export.Dim() API and passing them into torch.export.export() through the dynamic_shapes argument. An example:

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"):

            # code: out1 = self.branch1(x1)
            linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias)
            relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear)

            # code: out2 = self.branch2(x2)
            linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias)
            relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1)

            # code: return (out1 + self.buffer, out2)
            add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer)
            return (add, relu_1)

Range constraints: {s0: VR[0, int_oo]}

Some additional things to note:

We can also specify more expressive relationships between input shapes. For example, a pair of shapes might differ by one, one shape might be double another, or a shape might be even. An example:

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y[1:]

x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1

exported_program = torch.export.export(
    M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
)
print(exported_program)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"):
            # code: return x + y[1:]
            slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807)
            add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1)
            return (add,)

Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]}

Some things to note:

Serialization

To save the ExportedProgram, users can use the torch.export.save() and torch.export.load() APIs. A convention is to save the ExportedProgram using a .pt2 file extension.

An example:

import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

exported_program = torch.export.export(MyModule(), (torch.randn(5),))

torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')

Specializations

A key concept in understanding the behavior of torch.export is the difference between static and dynamic values.

A dynamic value is one that can change from run to run. These behave like normal arguments to a Python function—you can pass different values for an argument and expect your function to do the right thing. Tensor data is treated as dynamic.

A static value is a value that is fixed at export time and cannot change between executions of the exported program. When the value is encountered during tracing, the exporter will treat it as a constant and hard-code it into the graph.

When an operation is performed (e.g. x + y) and all inputs are static, then the output of the operation will be directly hard-coded into the graph, and the operation won’t show up (i.e. it will get constant-folded).

When a value has been hard-coded into the graph, we say that the graph has been specialized to that value.

The following values are static:

Input Tensor Shapes

By default, torch.export will trace the program specializing on the input tensors’ shapes, unless a dimension is specified as dynamic via the dynamic_shapes argument to torch.export. This means that if there exists shape-dependent control flow, torch.export will specialize on the branch that is being taken with the given sample inputs. For example:

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x):
        if x.shape[0] > 5:
            return x + 1
        else:
            return x - 1

example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 2]"):
            # code: return x + 1
            add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1)
            return (add,)

The conditional of (x.shape[0] > 5) does not appear in the ExportedProgram because the example inputs have the static shape of (10, 2). Since torch.export specializes on the inputs’ static shapes, the else branch (x - 1) will never be reached. To preserve the dynamic branching behavior based on the shape of a tensor in the traced graph, torch.export.Dim() will need to be used to specify the dimension of the input tensor (x.shape[0]) to be dynamic, and the source code will need to be rewritten.

Note that tensors that are part of the module state (e.g. parameters and buffers) always have static shapes.

Python Primitives

torch.export also specializes on Python primitives, such as int, float, bool, and str. However, they do have dynamic variants such as SymInt, SymFloat, and SymBool.

For example:

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, const: int, times: int):
        for i in range(times):
            x = x + const
        return x

example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[2, 2]", const, times):
            # code: x = x + const
            add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1)
            add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1)
            add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1)
            return (add_2,)

Because integers are specialized, the torch.ops.aten.add.Tensor operations are all computed with the hard-coded constant 1, rather than const. If a user passes a value for const at runtime (e.g. 2) that differs from the one used at export time (1), this will result in an error. Additionally, the times iterator used in the for loop is also “inlined” in the graph through the 3 repeated torch.ops.aten.add.Tensor calls, and the input times is never used.

Python Containers

Python containers (List, Dict, NamedTuple, etc.) are considered to have static structure.

Limitations of torch.export

Graph Breaks

As torch.export is a one-shot process for capturing a computation graph from a PyTorch program, it might ultimately run into untraceable parts of programs as it is nearly impossible to support tracing all PyTorch and Python features. In the case of torch.compile, an unsupported operation will cause a “graph break” and the unsupported operation will be run with default Python evaluation. In contrast, torch.export will require users to provide additional information or rewrite parts of their code to make it traceable. As the tracing is based on TorchDynamo, which evaluates at the Python bytecode level, there will be significantly fewer rewrites required compared to previous tracing frameworks.

When a graph break is encountered, ExportDB is a great resource for learning about the kinds of programs that are supported and unsupported, along with ways to rewrite programs to make them traceable.

One option for getting past these graph breaks is to use non-strict export.

Data/Shape-Dependent Control Flow

Graph breaks can also be encountered on data-dependent control flow (e.g. if x.shape[0] > 2) when shapes are not being specialized. A tracing compiler cannot handle such branches without generating code for a combinatorially exploding number of paths. In such cases, users will need to rewrite their code using special control flow operators. Currently, we support torch.cond to express if-else like control flow (more coming soon!).

Read More

API Reference

torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=False, preserve_module_call_signature=())

export() takes any nn.Module along with example inputs, and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized. The traced graph (1) produces normalized operators in the functional ATen operator set (as well as any user-specified custom operators), (2) has eliminated all Python control flow and data structures (with certain exceptions), and (3) records the set of shape constraints needed to show that this normalization and control-flow elimination is sound for future inputs.

Soundness Guarantee

While tracing, export() takes note of shape-related assumptions made by the user program and the underlying PyTorch operator kernels. The output ExportedProgram is considered valid only when these assumptions hold true.

Tracing makes assumptions on the shapes (not values) of input tensors. Such assumptions must be validated at graph capture time for export() to succeed. Specifically:

If any assumption cannot be validated, a fatal error will be raised. When that happens, the error message will include suggested fixes to the specification that are needed to validate the assumptions. For example, export() might suggest the following fix to the definition of a dynamic dimension dim0_x, say appearing in the shape associated with input x, that was previously defined as Dim("dim0_x"):

dim = Dim("dim0_x", max=5)

This example means the generated code requires dimension 0 of input x to be less than or equal to 5 to be valid. You can inspect the suggested fixes to dynamic dimension definitions and then copy them verbatim into your code without needing to change the dynamic_shapes argument to your export() call.

Parameters

Returns

An ExportedProgram containing the traced callable.

Return type

ExportedProgram

Acceptable input/output types

Acceptable types of inputs (for args and kwargs) and outputs include:

torch.export.save(ep, f, *, extra_files=None, opset_version=None, pickle_protocol=2)

Warning

Under active development, saved files may not be usable in newer versions of PyTorch.

Saves an ExportedProgram to a file-like object. It can then be loaded using the Python API torch.export.load.

Parameters

Example:

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to file
torch.export.save(ep, 'exported_program.pt2')

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'.decode('utf-8')}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)

torch.export.load(f, *, extra_files=None, expected_opset_version=None)

Warning

Under active development, saved files may not be usable in newer versions of PyTorch.

Loads an ExportedProgram previously saved with torch.export.save.

Parameters

Returns

An ExportedProgram object

Return type

ExportedProgram

Example:

import torch
import io

# Load ExportedProgram from file
ep = torch.export.load('exported_program.pt2')

# Load ExportedProgram from io.BytesIO object
with open('exported_program.pt2', 'rb') as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)

# Load with extra files.
extra_files = {'foo.txt': ''}  # values will be replaced with data
ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
print(extra_files['foo.txt'])
print(ep.module()(torch.randn(5)))

torch.export.register_dataclass(cls, *, serialized_type_name=None)

Registers a dataclass as a valid input/output type for torch.export.export().

Parameters

Example:

import torch
from dataclasses import dataclass

@dataclass
class InputDataClass:
    feature: torch.Tensor
    bias: int

@dataclass
class OutputDataClass:
    res: torch.Tensor

torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)

class Mod(torch.nn.Module):
    def forward(self, x: InputDataClass) -> OutputDataClass:
        res = x.feature + x.bias
        return OutputDataClass(res=res)

ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), ))
print(ep)

class torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)

Dim() constructs a type analogous to a named symbolic integer with a range. It can be used to describe multiple possible values of a dynamic tensor dimension. Note that different dynamic dimensions of the same tensor, or of different tensors, can be described by the same type.

Parameters

Returns

A type that can be used in dynamic shape specifications for tensors.

torch.export.exported_program.default_decompositions()

This is the default decomposition table, which contains decompositions of all ATen operators to the core ATen opset. Use this API together with run_decompositions().

Return type

CustomDecompTable

torch.export.dims(*names, min=None, max=None)

Util to create multiple Dim() types.

Returns

A tuple of Dim() types.

Return type

tuple[torch.export.dynamic_shapes.Dim, …]

class torch.export.dynamic_shapes.ShapesCollection

Builder for dynamic_shapes. Used to assign dynamic shape specifications to tensors that appear in inputs.

This is particularly useful when args is a nested input structure, and it’s easier to index the input tensors than to replicate the structure of args in the dynamic_shapes specification.

Example:

args = ({"x": tensor_x, "others": [tensor_y, tensor_z]})

dim = torch.export.Dim(...)
dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[tensor_y] = {0: dim * 2}

This is equivalent to the following (now auto-generated):

dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

torch.export.export(..., args, dynamic_shapes=dynamic_shapes)

To specify dynamism for integers, we need to first wrap the integers using _IntWrapper so that we have a “unique identification tag” for each integer.

Example:

args = ({"x": tensor_x, "others": [int_x, int_y]})

# Wrap all ints with _IntWrapper
mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args)

dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[mapped_args["others"][0]] = Dim.DYNAMIC

This is equivalent to the following (now auto-generated):

dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [Dim.DYNAMIC, None]}

torch.export.export(..., args, dynamic_shapes=dynamic_shapes)

dynamic_shapes(m, args, kwargs=None)

Generates the dynamic_shapes pytree structure according to args and kwargs.

class torch.export.dynamic_shapes.AdditionalInputs

Infers dynamic_shapes based on additional inputs.

This is useful particularly for deployment engineers who, on the one hand, may have access to ample testing or profiling data that can provide a fair sense of representative inputs for a model, but on the other hand, may not know enough about the model to guess which input shapes should be dynamic.

Input shapes that are different from the original are considered dynamic; conversely, those that are the same as the original are considered static. Moreover, we verify that the additional inputs are valid for the exported program. This guarantees that tracing with them instead of the original would have generated the same graph.

Example:

args0, kwargs0 = ...  # example inputs for export

# other representative inputs that the exported program will run on
dynamic_shapes = torch.export.AdditionalInputs()
dynamic_shapes.add(args1, kwargs1)
...
dynamic_shapes.add(argsN, kwargsN)

torch.export.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes)

add(args, kwargs=None)

Additional input args and kwargs.

dynamic_shapes(m, args, kwargs=None)

Infers a dynamic_shapes pytree structure by merging the shapes of the original input args and kwargs with those of each additional input's args and kwargs.

verify(ep)

Verifies that an exported program is valid for each additional input.

torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)

When exporting with dynamic_shapes, export may fail with a ConstraintViolation error if the specification doesn’t match the constraints inferred from tracing the model. The error message may provide suggested fixes: changes that can be made to dynamic_shapes to export successfully.

Example ConstraintViolation error message:

Suggested fixes:

dim = Dim('dim', min=3, max=6)  # this just refines the dim's range
dim = 4  # this specializes to a constant
dy = dx + 1  # dy was specified as an independent dim, but is actually tied to dx with this relation

This is a helper function that takes the ConstraintViolation error message and the original dynamic_shapes spec, and returns a new dynamic_shapes spec that incorporates the suggested fixes.

Example usage:

try:
    ep = export(mod, args, dynamic_shapes=dynamic_shapes)
except torch._dynamo.exc.UserError as exc:
    new_shapes = refine_dynamic_shapes_from_suggested_fixes(
        exc.msg, dynamic_shapes
    )
    ep = export(mod, args, dynamic_shapes=new_shapes)

Return type

Union[dict[str, Any], tuple[Any], list[Any]]

torch.export.Constraint

alias of Union[_Constraint, _DerivedConstraint, _RelaxedConstraint]

class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)

Package of a program from export(). It contains a torch.fx.Graph that represents Tensor computation, a state_dict containing tensor values of all lifted parameters and buffers, and various metadata.

To run the program, call the module returned by module(), which uses the same calling convention as the original callable traced by export().

To perform transformations on the graph, use the module() method to access a torch.fx.GraphModule. You can then use FX transformation to rewrite the graph. Afterwards, you can simply use export() again to construct a correct ExportedProgram.

module()

Returns a self contained GraphModule with all the parameters/buffers inlined.

Return type

Module

buffers()

Returns an iterator over original module buffers.

Warning

This API is experimental and is NOT backward-compatible.

Return type

Iterator[Tensor]

named_buffers()[source]#

Returns an iterator over original module buffers, yielding both the name of the buffer as well as the buffer itself.

Warning

This API is experimental and is NOT backward-compatible.

Return type

Iterator[tuple[str, torch.Tensor]]

parameters()[source]#

Returns an iterator over original module’s parameters.

Warning

This API is experimental and is NOT backward-compatible.

Return type

Iterator[Parameter]

named_parameters()[source]#

Returns an iterator over original module parameters, yielding both the name of the parameter as well as the parameter itself.

Warning

This API is experimental and is NOT backward-compatible.

Return type

Iterator[tuple[str, torch.nn.parameter.Parameter]]

run_decompositions(decomp_table=None, decompose_custom_triton_ops=False)[source]#

Runs a set of decompositions on the exported program and returns a new exported program. By default, the Core ATen decompositions are run to get operators in the Core ATen Operator Set.

For now, we do not decompose joint graphs.

Parameters

decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]) – An optional argument that specifies the decomposition behavior for ATen ops: (1) if None, operators are decomposed using the Core ATen decompositions; (2) if empty, no operator is decomposed.

Return type

ExportedProgram

Some examples:

If you don’t want to decompose anything:

ep = torch.export.export(model, ...)
ep = ep.run_decompositions(decomp_table={})

If you want the Core ATen operator set, but with your own decomposition for a particular operator, you can do the following:

ep = torch.export.export(model, ...)
decomp_table = torch.export.default_decompositions()
decomp_table[your_op] = your_custom_decomp
ep = ep.run_decompositions(decomp_table=decomp_table)

class torch.export.ExportBackwardSignature(gradients_to_parameters: dict[str, str], gradients_to_user_inputs: dict[str, str], loss_output: str)[source]#

class torch.export.ExportGraphSignature(input_specs, output_specs)[source]#

ExportGraphSignature models the input/output signature of Export Graph, which is an fx.Graph with stronger invariant guarantees.

Export Graph is functional and does not access “states” like parameters or buffers within the graph via getattr nodes. Instead, export() guarantees that parameters, buffers, and constant tensors are lifted out of the graph as inputs. Similarly, mutations to buffers are not included in the graph either; instead, the updated values of mutated buffers are modeled as additional outputs of Export Graph.

The ordering of all inputs and outputs is:

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

For example, if the following module is exported:

import torch
import torch.nn as nn

class CustomModule(nn.Module):
    def __init__(self) -> None:
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer('my_buffer1', torch.tensor(3.0))
        self.register_buffer('my_buffer2', torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0)  # In-place addition

        return output

The resulting graph would be:

graph():
    %arg0_1 := placeholder[target=arg0_1]
    %arg1_1 := placeholder[target=arg1_1]
    %arg2_1 := placeholder[target=arg2_1]
    %arg3_1 := placeholder[target=arg3_1]
    %arg4_1 := placeholder[target=arg4_1]
    %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
    %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
    %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
    %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
    %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
    return (add_tensor_2, add_tensor_1)

The resulting ExportGraphSignature would be:

ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
    ]
)

class torch.export.ModuleCallSignature(inputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec, forward_arg_names: Optional[list[str]] = None)[source]#

class torch.export.ModuleCallEntry(fqn: str, signature: Optional[torch.export.exported_program.ModuleCallSignature] = None)[source]#

class torch.export.decomp_utils.CustomDecompTable[source]#

This is a custom dictionary used specifically for handling decomp_table in export. It is needed because, in the new scheme, the only way to preserve an op is to delete it from the decomposition table. This is problematic for custom ops, because we don’t know when a custom op will actually be loaded into the dispatcher. As a result, we need to record operations on custom ops until we really need to materialize them (which is when the decomposition pass runs).

The invariants we maintain are:

  1. All ATen decompositions are loaded at init time.
  2. We materialize ALL ops whenever the user reads from the table, to make it more likely that the dispatcher picks up the custom op.
  3. Write operations do not necessarily trigger materialization.
  4. We load the table one final time during export, right before calling run_decompositions().

copy()[source]#

Return type

CustomDecompTable

items()[source]#

keys()[source]#

materialize()[source]#

Return type

dict[torch._ops.OperatorBase, Callable]

pop(*args)[source]#

update(other_dict)[source]#

class torch.export.graph_signature.InputKind(value)[source]#

An enumeration.

class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str], persistent: Optional[bool] = None)[source]#

class torch.export.graph_signature.OutputKind(value)[source]#

An enumeration.

class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])[source]#

class torch.export.graph_signature.SymIntArgument(name: str)[source]#

class torch.export.graph_signature.SymBoolArgument(name: str)[source]#

class torch.export.graph_signature.SymFloatArgument(name: str)[source]#

class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[source]#

This is the same class as torch.export.ExportGraphSignature, documented above; see that entry for the input/output ordering and a worked example.

replace_all_uses(old, new)[source]#

Replaces all uses of the old name with the new name in the signature.

get_replace_hook(replace_inputs=False)[source]#

class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)[source]#

class torch.export.unflatten.FlatArgsAdapter[source]#

Adapts input arguments that follow input_spec so that they align with target_spec.

abstract adapt(target_spec, input_spec, input_args, metadata=None, obj=None)[source]#

NOTE: This adapter may mutate the given input_args_with_path.

Return type

list[Any]

class torch.export.unflatten.InterpreterModule(graph, ty=None)[source]#

A module that uses torch.fx.Interpreter to execute instead of the usual codegen that GraphModule uses. This provides better stack trace information and makes it easier to debug execution.

class torch.export.unflatten.InterpreterModuleDispatcher(attrs, call_modules)[source]#

A module that carries a sequence of InterpreterModules corresponding to a sequence of calls of that module. Each call to the module dispatches to the next InterpreterModule, and wraps back around after the last.

torch.export.unflatten.unflatten(module, flat_args_adapter=None)[source]#

Unflattens an ExportedProgram, producing a module with the same module hierarchy as the original eager module. This can be useful if you are trying to use torch.export with another system that expects a module hierarchy instead of the flat graph that torch.export usually produces.

Note

The args/kwargs of unflattened modules will not necessarily match the eager module, so doing a module swap (e.g. self.submod = new_mod) will not necessarily work. If you need to swap a module out, you need to set the preserve_module_call_signature parameter of torch.export.export().

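Parameters

module (ExportedProgram) – The ExportedProgram to unflatten.

flat_args_adapter (Optional[FlatArgsAdapter]) – An adapter used to align the flat input arguments with the exported module’s input spec when they do not match.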

Returns

An instance of UnflattenedModule, which has the same module hierarchy as the original eager module pre-export.

Return type

UnflattenedModule

torch.export.passes.move_to_device_pass(ep, location)[source]#

Move the exported program to the given device.

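Parameters

ep (ExportedProgram) – The exported program to move.

location – The device to move the exported program to.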

Returns

The moved exported program.

Return type

ExportedProgram