torch.export IR Specification — PyTorch 2.7 documentation

Export IR is an intermediate representation (IR) for compilers, which bears similarities to MLIR and TorchScript. It is specifically designed to express the semantics of PyTorch programs. Export IR primarily represents computation as a streamlined list of operations, with limited support for dynamism such as control flow.

To create an Export IR graph, a frontend is used that soundly captures a PyTorch program via a trace-specializing mechanism. The resulting Export IR can then be optimized and executed by a backend. This can be done today through torch.export.export().

The key concepts covered in this document include:

- ExportedProgram: the top-level construct that bundles a computational graph with its parameters
- Graph, Node, and the FX text format
- The node types: call_function, placeholder, output, and get_attr
- SymInt and FakeTensor
- Pytree-able types

Assumptions

This doc assumes that the audience is sufficiently familiar with PyTorch, specifically with torch.fx and its related tooling. Thus it will not describe content already present in the torch.fx documentation and paper.

What is Export IR

Export IR is a graph-based intermediate representation (IR) of PyTorch programs. Export IR is realized on top of torch.fx.Graph. In other words, all Export IR graphs are also valid FX graphs, and if interpreted using standard FX semantics, Export IR can be interpreted soundly. One implication is that an exported graph can be converted to a valid Python program via standard FX codegen.

This documentation will primarily focus on highlighting areas where Export IR differs from FX in terms of its strictness, while skipping parts where it shares similarities with FX.

ExportedProgram

The top-level Export IR construct is the torch.export.ExportedProgram class. It bundles the computational graph of a PyTorch model (which is usually a torch.nn.Module) with the parameters or weights that this model consumes.

Some notable attributes of the torch.export.ExportedProgram class are:

- graph: the torch.fx.Graph encoding the computation
- graph_signature: the signature of the graph, describing which graph inputs correspond to parameters, buffers, and user inputs, and which outputs are user outputs
- state_dict: the parameters and buffers consumed by the graph
- range_constraints: constraints on the symbolic integers (SymInts) appearing in the graph
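As a minimal sketch (the module M, its layer sizes, and the example inputs below are made up for illustration), these attributes can be inspected directly on the result of torch.export.export():

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

ep = torch.export.export(M(), (torch.randn(3, 4),))
print(ep.graph)              # the underlying torch.fx.Graph
print(ep.graph_signature)    # how graph inputs/outputs map to parameters, buffers, and user values
print(ep.state_dict.keys())  # the lifted parameters and buffers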

Graph

An Export IR Graph is a PyTorch program represented in the form of a DAG (directed acyclic graph). Each node in this graph represents a particular computation or operation, and edges of this graph consist of references between nodes.

We can view a Graph as having this schema:

class Graph:
    nodes: List[Node]

In practice, Export IR’s graph is realized as the torch.fx.Graph Python class.

An Export IR graph contains the following kinds of nodes (Nodes will be described in more detail in the next section):

- placeholder nodes, representing the inputs of the graph
- call_function nodes, representing calls to operators
- get_attr nodes, representing reads of submodules
- exactly one output node, representing the outputs of the graph

Corollary: The smallest valid Graph consists of a single node, i.e. nodes is never empty.

**Definition:** The set of placeholder nodes of a Graph represents the inputs of the Graph or GraphModule. The output node of a Graph represents the outputs of the Graph or GraphModule.

Example:

import torch
from torch import nn

class MyModule(nn.Module):
    def forward(self, x, y):
        return x + y

example_args = (torch.randn(1), torch.randn(1))
mod = torch.export.export(MyModule(), example_args)
print(mod.graph)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
    return (add,)

The above is the textual representation of a Graph, with each line being a node.

Node

A Node represents a particular computation or operation and is represented in Python using the torch.fx.Node class. Edges between nodes are represented as direct references to other nodes via the args property of the Node class. Using the same FX machinery, we can represent the operations that a computational graph typically needs, such as operator calls, placeholders (a.k.a. inputs), conditionals, and loops.

The Node has the following schema:

class Node:
    name: str     # name of node
    op_name: str  # type of operation

    # interpretation of the fields below depends on op_name
    target: [str|Callable]
    args: List[object]
    kwargs: Dict[str, object]
    meta: Dict[str, object]
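Since Export IR nodes are ordinary torch.fx.Node objects, these fields can be inspected with the standard FX API. A short sketch, reusing the mod from the Graph example above (note that torch.fx spells the op_name field as node.op):

# Print the schema fields of every node in the exported graph.
for node in mod.graph.nodes:
    print(node.name, node.op, node.target, node.args, node.kwargs)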

FX Text Format

As in the example above, notice that each line has this format:

%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …), kwargs = {"keyword": arg5})

This format captures everything present in the Node class, with the exception of meta, in a compact format.

Concretely:

- %<name> is the name of the node, as it would appear in node.name.
- :[...] prints additional node metadata, such as [num_users=1], the number of nodes that consume this node's output.
- <op_name> is the node's op_name field, identifying the kind of node.
- target=<target> is the node's target field.
- args and kwargs are printed as they appear in the node's args and kwargs fields; values prefixed with % are references to other nodes.

For example, a call to the add operator would appear as:

%add1 = call_function[target = torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})

Here %x and %y are two other Nodes that have names x and y. It is worth noting that the string torch.ops.aten.add.Tensor represents the callable object that is actually stored in the target field, not merely its string name.

The final line of this text format is:

return (add,)

which is a Node with op_name = output, indicating that we are returning this one element.

call_function

A call_function node represents a call to an operator.

Definitions

Representation in FX

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

Differences from vanilla FX call_function

  1. In an FX graph, a call_function can refer to any callable; in Export IR, we restrict it to a select subset of ATen operators, custom operators, and control flow operators.
  2. In Export IR, constant arguments are embedded within the graph (see the sketch below).
  3. In an FX graph, a get_attr node can represent reading any attribute stored in the graph module. However, in Export IR this is restricted to reading only submodules, as all parameters/buffers are passed in as inputs to the graph module.
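To illustrate point 2, here is a hedged sketch (the AddConst module and the constant 2.0 are made up): exporting a module that adds a Python constant should show the constant inlined as a literal argument of the call_function node.

import torch

class AddConst(torch.nn.Module):
    def forward(self, x):
        return x + 2.0  # a Python constant

ep = torch.export.export(AddConst(), (torch.randn(3),))
# The add node's args should look like (%x, 2.0): the constant is embedded
# directly in the graph rather than passed in as a placeholder or get_attr.
print(ep.graph)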

placeholder

Placeholder represents an input to a graph. Its semantics are exactly the same as in FX. Placeholder nodes must be the first N nodes in the nodes list of a graph. N can be zero.

Representation in FX

%name = placeholder[target = name](args = ())

The target field is a string, which is the name of the input.

args, if non-empty, should be of size 1, representing the default value of this input.

Metadata

Placeholder nodes also have meta['val'], like call_function nodes. The val field in this case represents the input shape/dtype that the graph is expected to receive for this input parameter.
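As a sketch, the expected input shapes and dtypes can be read off the placeholder nodes, reusing the mod from the Graph example above:

for node in mod.graph.nodes:
    if node.op == "placeholder":
        fake = node.meta["val"]  # a FakeTensor describing the expected input
        print(node.name, fake.shape, fake.dtype)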

output

An output call represents a return statement in a function; it thus terminates the current graph. There is one and only one output node, and it will always be the last node of the graph.

Representation in FX

output[](args = (%something, …))

This has exactly the same semantics as in torch.fx. args represents the nodes to be returned.

Metadata

The output node has the same metadata as call_function nodes.

get_attr

get_attr nodes represent reading a submodule from the encapsulating torch.fx.GraphModule. Unlike in a vanilla FX graph from torch.fx.symbolic_trace(), where get_attr nodes are used to read attributes such as parameters and buffers from the top-level torch.fx.GraphModule, in Export IR parameters and buffers are passed in as inputs to the graph module and stored in the top-level torch.export.ExportedProgram.

Representation in FX

%name = get_attr[target = name](args = ())

Example

Consider the following model:

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

Graph:

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

The line %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] reads the submodule true_graph_0, which contains the sin operator.
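For completeness, a hedged sketch of how such a graph could be produced with torch.export (the CondModule wrapper and the example inputs are made up; note that the predicate is passed as a boolean tensor):

import torch
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

class CondModule(torch.nn.Module):
    def forward(self, x, y):
        return cond(y, true_fn, false_fn, [x])

ep = torch.export.export(CondModule(), (torch.randn(3), torch.tensor(True)))
print(ep.graph)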

References

SymInt

A SymInt is an object that can be either a literal integer or a symbol that represents an integer (represented in Python by the sympy.Symbol class). When a SymInt is a symbol, it describes an integer variable that is unknown to the graph at compile time; its value is only known at runtime.
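As a sketch of where SymInts appear (the Dim name, module, and shapes below are made up), exporting with a dynamic dimension makes the corresponding size symbolic:

import torch

class Mul2(torch.nn.Module):
    def forward(self, x):
        return x * 2

batch = torch.export.Dim("batch")
ep = torch.export.export(Mul2(), (torch.randn(4, 8),), dynamic_shapes={"x": {0: batch}})
for node in ep.graph.nodes:
    if node.op == "placeholder":
        # The first dimension prints as a symbol (e.g. s0) rather than 4.
        print(node.meta["val"].shape)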

FakeTensor

A FakeTensor is an object that contains the metadata of a tensor. It can be viewed as having the following schema:

class FakeTensor:
    size: List[SymInt]
    dtype: torch.dtype
    device: torch.device
    dim_order: List[int]  # This doesn't exist yet

The size field of FakeTensor is a list of integers or SymInts. If SymInts are present, the tensor has a dynamic shape. If only integers are present, the tensor is assumed to have that exact static shape. The rank of a FakeTensor is never dynamic. The dtype field represents the dtype of the output of that node. There are no implicit type promotions in Export IR. There are no strides in FakeTensor.

In other words:

- If every entry in size is a plain integer, the tensor has a static shape.
- If any entry in size is a SymInt, that dimension is dynamic and its concrete value is only known at runtime.

For example:

Python code:

def add_one(x):
    return torch.ops.aten.add.Tensor(x, 1)

Graph:

graph():
    %ph_0 : [#users=1] = placeholder[target=ph_0]
    %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
    return [add_tensor]

FakeTensor:

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

Pytree-able Types

We define a type as "Pytree-able" if it is either a leaf type or a container type that contains other Pytree-able types.

Note:

The concept of pytree is the same as the one documented here for JAX.

The following types are defined as leaf type:

| Type | Definition |
|---|---|
| Tensor | torch.Tensor |
| Scalar | Any numerical type from Python, including integral types, floating point types, and zero-dimensional tensors |
| int | Python int (bound as int64_t in C++) |
| float | Python float (bound as double in C++) |
| bool | Python bool |
| str | Python string |
| ScalarType | torch.dtype |
| Layout | torch.layout |
| MemoryFormat | torch.memory_format |
| Device | torch.device |

The following types are defined as container type:

| Type | Definition |
|---|---|
| Tuple | Python tuple |
| List | Python list |
| Dict | Python dict with Scalar keys |
| NamedTuple | Python namedtuple |
| Dataclass | Must be registered through register_dataclass |
| Custom class | Any custom class defined with _register_pytree_node |
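As a sketch of the Dataclass row (the Pair class and module below are made up), a dataclass becomes Pytree-able once registered:

import torch
from dataclasses import dataclass

@dataclass
class Pair:
    a: torch.Tensor
    b: torch.Tensor

# Register Pair so export can flatten/unflatten it like a container type.
torch.export.register_dataclass(Pair)

class M(torch.nn.Module):
    def forward(self, p: Pair):
        return p.a + p.b

ep = torch.export.export(M(), (Pair(torch.randn(2), torch.randn(2)),))
print(ep.graph)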