torch.export API Reference

Created On: Jul 17, 2025 | Last Updated On: Jul 17, 2025

torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=False, preserve_module_call_signature=(), prefer_deferred_runtime_asserts_over_guards=False)[source]#

export() takes any nn.Module along with example inputs, and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized. The traced graph (1) produces normalized operators in the functional ATen operator set (as well as any user-specified custom operators), (2) has eliminated all Python control flow and data structures (with certain exceptions), and (3) records the set of shape constraints needed to show that this normalization and control-flow elimination is sound for future inputs.

Soundness Guarantee

While tracing, export() takes note of shape-related assumptions made by the user program and the underlying PyTorch operator kernels. The output ExportedProgram is considered valid only when these assumptions hold true.

Tracing makes assumptions on the shapes (not values) of input tensors. Such assumptions must be validated at graph capture time for export() to succeed.

If any assumption cannot be validated, a fatal error will be raised. When that happens, the error message will include suggested fixes to the specification that are needed to validate the assumptions. For example, export() might suggest the following fix to the definition of a dynamic dimension dim0_x, say appearing in the shape associated with input x, that was previously defined as Dim("dim0_x"):

dim = Dim("dim0_x", max=5)

This example means the generated code requires dimension 0 of input x to be less than or equal to 5 to be valid. You can inspect the suggested fixes to dynamic dimension definitions and then copy them verbatim into your code without needing to change the dynamic_shapes argument to your export() call.

Parameters

Returns

An ExportedProgram containing the traced callable.

Return type

ExportedProgram

Acceptable input/output types

Acceptable types of inputs (for args and kwargs) and outputs include:

class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)[source]#

Package of a program from export(). It contains a torch.fx.Graph that represents Tensor computation, a state_dict containing tensor values of all lifted parameters and buffers, and various metadata.

You can call the module returned by module() like the original callable traced by export(), with the same calling convention.

To perform transformations on the graph, call module() to access a torch.fx.GraphModule. You can then use FX transformations to rewrite the graph. Afterwards, you can simply use export() again to construct a correct ExportedProgram.

buffers()[source]#

Returns an iterator over original module buffers.

Warning

This API is experimental and is NOT backward-compatible.

Return type

Iterator[Tensor]

property call_spec#

Warning

This API is experimental and is NOT backward-compatible.

property constants#

Warning

This API is experimental and is NOT backward-compatible.

property dialect: str

Warning

This API is experimental and is NOT backward-compatible.

property example_inputs#

Warning

This API is experimental and is NOT backward-compatible.

property graph#

Warning

This API is experimental and is NOT backward-compatible.

property graph_module#

Warning

This API is experimental and is NOT backward-compatible.

property graph_signature#

Warning

This API is experimental and is NOT backward-compatible.

module(check_guards=True)[source]#

Returns a self-contained GraphModule with all the parameters/buffers inlined.

Return type

GraphModule

property module_call_graph#

Warning

This API is experimental and is NOT backward-compatible.

named_buffers()[source]#

Returns an iterator over original module buffers, yielding both the name of the buffer as well as the buffer itself.

Warning

This API is experimental and is NOT backward-compatible.

Return type

Iterator[tuple[str, torch.Tensor]]

named_parameters()[source]#

Returns an iterator over original module parameters, yielding both the name of the parameter as well as the parameter itself.

Warning

This API is experimental and is NOT backward-compatible.

Return type

Iterator[tuple[str, torch.nn.parameter.Parameter]]

parameters()[source]#

Returns an iterator over original module’s parameters.

Warning

This API is experimental and is NOT backward-compatible.

Return type

Iterator[Parameter]

property range_constraints#

Warning

This API is experimental and is NOT backward-compatible.

run_decompositions(decomp_table=None, decompose_custom_triton_ops=False)[source]#

Runs a set of decompositions on the exported program and returns a new exported program. By default, the Core ATen decompositions are run to get operators in the Core ATen Operator Set.

For now, we do not decompose joint graphs.

Parameters

decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]) – An optional argument that specifies decomposition behavior for ATen ops: (1) if None, we decompose using the Core ATen decompositions; (2) if empty, we don’t decompose any operator.

Return type

ExportedProgram

Some examples:

If you don’t want to decompose anything:

ep = torch.export.export(model, ...)
ep = ep.run_decompositions(decomp_table={})

If you want the Core ATen operator set except for certain operators, you can do the following:

ep = torch.export.export(model, ...)
decomp_table = torch.export.default_decompositions()
decomp_table[your_op] = your_custom_decomp
ep = ep.run_decompositions(decomp_table=decomp_table)

property state_dict#

Warning

This API is experimental and is NOT backward-compatible.

property tensor_constants#

Warning

This API is experimental and is NOT backward-compatible.

validate()[source]#

Warning

This API is experimental and is NOT backward-compatible.

property verifier: Any

Warning

This API is experimental and is NOT backward-compatible.

property verifiers#

Warning

This API is experimental and is NOT backward-compatible.

class torch.export.dynamic_shapes.AdditionalInputs[source]#

Infers dynamic_shapes based on additional inputs.

This is useful particularly for deployment engineers who, on the one hand, may have access to ample testing or profiling data that can provide a fair sense of representative inputs for a model, but on the other hand, may not know enough about the model to guess which input shapes should be dynamic.

Input shapes that are different than the original are considered dynamic; conversely, those that are the same as the original are considered static. Moreover, we verify that the additional inputs are valid for the exported program. This guarantees that tracing with them instead of the original would have generated the same graph.

Example:

args0, kwargs0 = ...  # example inputs for export

# other representative inputs that the exported program will run on
dynamic_shapes = torch.export.AdditionalInputs()
dynamic_shapes.add(args1, kwargs1)
...
dynamic_shapes.add(argsN, kwargsN)

torch.export.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes)

add(args, kwargs=None)[source]#

Adds an additional input, given as args and kwargs.

dynamic_shapes(m, args, kwargs=None)[source]#

Infers a dynamic_shapes pytree structure by merging the shapes of the original input args and kwargs with those of each additional input args and kwargs.

verify(ep)[source]#

Verifies that an exported program is valid for each additional input.

class torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[source]#

The Dim class allows users to specify dynamism in their exported programs. By marking a dimension with a Dim, the compiler associates the dimension with a symbolic integer containing a dynamic range.

The API can be used in 2 ways: Dim hints (i.e. automatic dynamic shapes: Dim.AUTO, Dim.DYNAMIC, Dim.STATIC), or named Dims (i.e. Dim("name", min=1, max=2)).

Dim hints provide the lowest barrier to exportability, with the user only needing to specify whether a dimension is dynamic, static, or left for the compiler to decide (Dim.AUTO). The export process will automatically infer the remaining constraints on min/max ranges and relationships between dimensions.

Example:

class Foo(nn.Module):
    def forward(self, x, y):
        assert x.shape[0] == 4
        assert y.shape[1] >= 16
        return x @ y

x = torch.randn(4, 8)
y = torch.randn(8, 16)
dynamic_shapes = {
    "x": {0: Dim.AUTO, 1: Dim.AUTO},
    "y": {0: Dim.AUTO, 1: Dim.AUTO},
}
ep = torch.export.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)

Here, export would raise an exception if we replaced all uses of Dim.AUTO with Dim.DYNAMIC, as x.shape[0] is constrained to be static by the model.

More complex relations between dimensions may also be codegened as runtime assertion nodes by the compiler, e.g. (x.shape[0] + y.shape[1]) % 4 == 0, to be raised if runtime inputs do not satisfy such constraints.

You may also specify min-max bounds for Dim hints, e.g. Dim.AUTO(min=16, max=32), Dim.DYNAMIC(max=64), with the compiler inferring the remaining constraints within the ranges. An exception will be raised if the valid range is entirely outside the user-specified range.

Named Dims provide a stricter way of specifying dynamism, where exceptions are raised if the compiler infers constraints that do not match the user specification. For example, to export the previous model, the user would need the following dynamic_shapes argument:

s0 = Dim("s0")
s1 = Dim("s1", min=16)
dynamic_shapes = {
    "x": {0: 4, 1: s0},
    "y": {0: s0, 1: s1},
}
ep = torch.export.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)

Named Dims also allow specification of relationships between dimensions, up to univariate linear relations. For example, the following indicates one dimension is a multiple of another plus 4:

s0 = Dim("s0")
s1 = 3 * s0 + 4

class torch.export.dynamic_shapes.ShapesCollection[source]#

Builder for dynamic_shapes. Used to assign dynamic shape specifications to tensors that appear in inputs.

This is particularly useful when args is a nested input structure, and it’s easier to index the input tensors than to replicate the structure of args in the dynamic_shapes specification.

Example:

args = {"x": tensor_x, "others": [tensor_y, tensor_z]}

dim = torch.export.Dim(...)
dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[tensor_y] = {0: dim * 2}

# This is equivalent to the following (now auto-generated):
# dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

torch.export.export(..., args, dynamic_shapes=dynamic_shapes)

To specify dynamism for integers, we need to first wrap the integers using _IntWrapper so that we have a “unique identification tag” for each integer.

Example:

args = {"x": tensor_x, "others": [int_x, int_y]}

# Wrap all ints with _IntWrapper
mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args)

dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[mapped_args["others"][0]] = Dim.DYNAMIC

# This is equivalent to the following (now auto-generated):
# dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [Dim.DYNAMIC, None]}

torch.export.export(..., args, dynamic_shapes=dynamic_shapes)

dynamic_shapes(m, args, kwargs=None)[source]#

Generates the dynamic_shapes pytree structure according to args and kwargs.

torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[source]#

When exporting with dynamic_shapes, export may fail with a ConstraintViolation error if the specification doesn’t match the constraints inferred from tracing the model. The error message may provide suggested fixes, that is, changes that can be made to dynamic_shapes to export successfully.

Example ConstraintViolation error message:

Suggested fixes:

dim = Dim('dim', min=3, max=6)  # this just refines the dim's range
dim = 4  # this specializes to a constant
dy = dx + 1  # dy was specified as an independent dim, but is actually tied to dx with this relation

This is a helper function that takes the ConstraintViolation error message and the original dynamic_shapes spec, and returns a new dynamic_shapes spec that incorporates the suggested fixes.

Example usage:

try:
    ep = export(mod, args, dynamic_shapes=dynamic_shapes)
except torch._dynamo.exc.UserError as exc:
    new_shapes = refine_dynamic_shapes_from_suggested_fixes(
        exc.msg, dynamic_shapes
    )
    ep = export(mod, args, dynamic_shapes=new_shapes)

Return type

Union[dict[str, Any], tuple[Any], list[Any]]

torch.export.save(ep, f, *, extra_files=None, opset_version=None, pickle_protocol=2)[source]#

Warning

This API is under active development; saved files may not be usable in newer versions of PyTorch.

Saves an ExportedProgram to a file-like object. It can then be loaded using the Python API torch.export.load.

Parameters

Example:

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to file
torch.export.save(ep, "exported_program.pt2")

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)

# Save with extra files
extra_files = {"foo.txt": b"bar".decode("utf-8")}
torch.export.save(ep, "exported_program.pt2", extra_files=extra_files)

torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source]#

Warning

This API is under active development; saved files may not be usable in newer versions of PyTorch.

Loads an ExportedProgram previously saved with torch.export.save.

Parameters

Returns

An ExportedProgram object

Return type

ExportedProgram

Example:

import torch
import io

# Load ExportedProgram from file
ep = torch.export.load("exported_program.pt2")

# Load ExportedProgram from io.BytesIO object
with open("exported_program.pt2", "rb") as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)

# Load with extra files.
extra_files = {"foo.txt": ""}  # values will be replaced with data
ep = torch.export.load("exported_program.pt2", extra_files=extra_files)
print(extra_files["foo.txt"])
# Run the loaded program through the GraphModule returned by module()
print(ep.module()(torch.randn(5)))

torch.export.pt2_archive._package.package_pt2(f, *, exported_programs=None, aoti_files=None, extra_files=None, opset_version=None, pickle_protocol=2)[source]#

Saves the artifacts to a PT2Archive format. The artifact can then be loaded using load_pt2.

Parameters

Return type

Union[str, PathLike[str], IO[bytes]]

torch.export.pt2_archive._package.load_pt2(f, *, expected_opset_version=None, run_single_threaded=False, num_runners=1, device_index=-1, load_weights_from_disk=False)[source]#

Loads all the artifacts previously saved with package_pt2.

Parameters

Returns

A PT2ArchiveContents object which contains all the objects in the PT2.

Return type

PT2ArchiveContents

torch.export.draft_export(mod, args, kwargs=None, *, dynamic_shapes=None, preserve_module_call_signature=(), strict=False, prefer_deferred_runtime_asserts_over_guards=False)[source]#

A version of torch.export.export which is designed to consistently produce an ExportedProgram, even if there are potential soundness issues, and to generate a report listing the issues found.

Return type

ExportedProgram

class torch.export.unflatten.FlatArgsAdapter[source]#

Adapts input arguments with input_spec to align target_spec.

abstract adapt(target_spec, input_spec, input_args, metadata=None, obj=None)[source]#

NOTE: This adapter may mutate given input_args_with_path.

Return type

list[Any]

get_flat_arg_paths()[source]#

Returns a list of paths that are used to access the flat args.

Return type

list[str]

class torch.export.unflatten.InterpreterModule(graph, ty=None)[source]#

A module that uses torch.fx.Interpreter to execute instead of the usual codegen that GraphModule uses. This provides better stack trace information and makes it easier to debug execution.

class torch.export.unflatten.InterpreterModuleDispatcher(attrs, call_modules)[source]#

A module that carries a sequence of InterpreterModules corresponding to a sequence of calls of that module. Each call to the module dispatches to the next InterpreterModule, and wraps back around after the last.

torch.export.unflatten.unflatten(module, flat_args_adapter=None)[source]#

Unflatten an ExportedProgram, producing a module with the same module hierarchy as the original eager module. This can be useful if you are trying to use torch.export with another system that expects a module hierarchy instead of the flat graph that torch.export usually produces.

Note

The args/kwargs of unflattened modules will not necessarily match the eager module, so doing a module swap (e.g. self.submod = new_mod) will not necessarily work. If you need to swap a module out, you need to set the preserve_module_call_signature parameter of torch.export.export().

Parameters

Returns

An instance of UnflattenedModule, which has the same module hierarchy as the original eager module pre-export.

Return type

UnflattenedModule

torch.export.register_dataclass(cls, *, serialized_type_name=None)[source]#

Registers a dataclass as a valid input/output type for torch.export.export().

Parameters

Example:

import torch
from dataclasses import dataclass

@dataclass
class InputDataClass:
    feature: torch.Tensor
    bias: int

@dataclass
class OutputDataClass:
    res: torch.Tensor

torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)

class Mod(torch.nn.Module):
    def forward(self, x: InputDataClass) -> OutputDataClass:
        res = x.feature + x.bias
        return OutputDataClass(res=res)

ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),))
print(ep)

class torch.export.decomp_utils.CustomDecompTable[source]#

This is a custom dictionary used specifically for handling decomp_table in export. It is needed because the only way to preserve an op is to delete it from the decomp table. This is problematic for custom ops because we don’t know when a custom op will actually be loaded into the dispatcher. As a result, we record the operations on custom ops and defer them until the table really needs to be materialized (which is when we run the decomposition pass).

Invariants we hold are:

  1. All ATen decompositions are loaded at init time.
  2. We materialize ALL ops whenever the user reads from the table, to make it more likely that the dispatcher picks up the custom op.
  3. A write operation does not necessarily materialize the table.
  4. We load one final time during export, right before calling run_decompositions().

copy()[source]#

Return type

CustomDecompTable

items()[source]#

keys()[source]#

materialize()[source]#

Return type

dict[torch._ops.OperatorBase, Callable]

pop(*args)[source]#

update(other_dict)[source]#

torch.export.passes.move_to_device_pass(ep, location)[source]#

Move the exported program to the given device.

Parameters

Returns

The moved exported program.

Return type

ExportedProgram

class torch.export.pt2_archive.PT2ArchiveReader(archive_path_or_buffer)#

Context manager for reading a PT2 archive.

archive_version()[source]#

Get the archive version.

Return type

int

get_file_names()[source]#

Get the file names in the archive.

Return type

list[str]

read_bytes(name)[source]#

Read a bytes object from the archive. name: The source file inside the archive.

Return type

bytes

read_string(name)[source]#

Read a string object from the archive. name: The source file inside the archive.

Return type

str

class torch.export.pt2_archive.PT2ArchiveWriter(archive_path_or_buffer)#

Context manager for writing a PT2 archive.

close()[source]#

Close the archive.

count_prefix(prefix)[source]#

Count the number of records that start with a given prefix.

Return type

int

has_record(name)[source]#

Check if a record exists in the archive.

Return type

bool

write_bytes(name, data)[source]#

Write a bytes object to the archive. name: The destination file inside the archive. data: The bytes object to write.

write_file(name, file_path)[source]#

Copy a file into the archive. name: The destination file inside the archive. file_path: The source file on disk.

write_folder(archive_dir, folder_dir)[source]#

Copy a folder into the archive. archive_dir: The destination folder inside the archive. folder_dir: The source folder on disk.

write_string(name, data)[source]#

Write a string object to the archive. name: The destination file inside the archive. data: The string object to write.

torch.export.pt2_archive.is_pt2_package(serialized_model)[source]#

Check if the serialized model is a PT2 Archive package.

Return type

bool

class torch.export.exported_program.ModuleCallEntry(fqn: str, signature: Optional[torch.export.exported_program.ModuleCallSignature] = None)[source]#

class torch.export.exported_program.ModuleCallSignature(inputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec, forward_arg_names: Optional[list[str]] = None)[source]#

torch.export.exported_program.default_decompositions()[source]#

This is the default decomposition table, which contains decompositions of all ATen operators to the Core ATen opset. Use this API together with run_decompositions().

Return type

CustomDecompTable

class torch.export.custom_obj.ScriptObjectMeta(constant_name, class_fqn)[source]#

Metadata which is stored on nodes representing ScriptObjects.

class torch.export.graph_signature.ConstantArgument(name: str, value: Union[int, float, bool, str, NoneType])[source]#

name: str

value: Optional[Union[int, float, bool, str]]

class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)[source]#

class_fqn: str

fake_val: Optional[FakeScriptObject] = None

name: str

class torch.export.graph_signature.ExportBackwardSignature(gradients_to_parameters: dict[str, str], gradients_to_user_inputs: dict[str, str], loss_output: str)[source]#

gradients_to_parameters: dict[str, str]

gradients_to_user_inputs: dict[str, str]

loss_output: str

class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[source]#

ExportGraphSignature models the input/output signature of an Export Graph, which is an fx.Graph with stronger invariant guarantees.

An Export Graph is functional and does not access “states” like parameters or buffers within the graph via getattr nodes. Instead, export() guarantees that parameters, buffers, and constant tensors are lifted out of the graph as inputs. Similarly, any mutations to buffers are not included in the graph either; instead, the updated values of mutated buffers are modeled as additional outputs of the Export Graph.

The ordering of all inputs and outputs is:

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

For example, if the following module is exported:

class CustomModule(nn.Module):
    def __init__(self) -> None:
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer("my_buffer1", torch.tensor(3.0))
        self.register_buffer("my_buffer2", torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (
            x1 + self.my_parameter
        ) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0)  # In-place addition

        return output

mod = CustomModule()
ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))

Resulting Graph is non-functional:

graph():
    %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
    %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
    %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
    return (add_1,)

Resulting ExportGraphSignature of the non-functional Graph would be:

# inputs
p_my_parameter: PARAMETER target='my_parameter'
b_my_buffer1: BUFFER target='my_buffer1' persistent=True
b_my_buffer2: BUFFER target='my_buffer2' persistent=True
x1: USER_INPUT
x2: USER_INPUT

# outputs
add_1: USER_OUTPUT

To get a functional Graph, you can use run_decompositions():

mod = CustomModule()
ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
ep = ep.run_decompositions()

Resulting Graph is functional:

graph():
    %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
    %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
    %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
    return (add_2, add_1)

Resulting ExportGraphSignature of the functional Graph would be:

# inputs
p_my_parameter: PARAMETER target='my_parameter'
b_my_buffer1: BUFFER target='my_buffer1' persistent=True
b_my_buffer2: BUFFER target='my_buffer2' persistent=True
x1: USER_INPUT
x2: USER_INPUT

# outputs
add_2: BUFFER_MUTATION target='my_buffer2'
add_1: USER_OUTPUT

property assertion_dep_token: Optional[Mapping[int, str]]

property backward_signature: Optional[ExportBackwardSignature]

property buffers: Collection[str]

property buffers_to_mutate: Mapping[str, str]

get_replace_hook(replace_inputs=False)

input_specs: list[torch.export.graph_signature.InputSpec]

property input_tokens: Collection[str]

property inputs_to_buffers: Mapping[str, str]

property inputs_to_lifted_custom_objs: Mapping[str, str]

property inputs_to_lifted_tensor_constants: Mapping[str, str]

property inputs_to_parameters: Mapping[str, str]

property lifted_custom_objs: Collection[str]

property lifted_tensor_constants: Collection[str]

property non_persistent_buffers: Collection[str]

output_specs: list[torch.export.graph_signature.OutputSpec]

property output_tokens: Collection[str]

property parameters: Collection[str]

property parameters_to_mutate: Mapping[str, str]

replace_all_uses(old, new)

Replace all uses of the old name with the new name in the signature.

property user_inputs: Collection[Union[int, float, bool, None, str]]

property user_inputs_to_mutate: Mapping[str, str]

property user_outputs: Collection[Union[int, float, bool, None, str]]

class torch.export.graph_signature.InputKind(value)[source]#

An enumeration.

BUFFER = 3

CONSTANT_TENSOR = 4

CUSTOM_OBJ = 5

PARAMETER = 2

TOKEN = 6

USER_INPUT = 1

class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str], persistent: Optional[bool] = None)[source]#

arg: Union[TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument, ConstantArgument, CustomObjArgument, TokenArgument]

kind: InputKind

persistent: Optional[bool] = None

target: Optional[str]

class torch.export.graph_signature.OutputKind(value)[source]#

An enumeration.

BUFFER_MUTATION = 3

GRADIENT_TO_PARAMETER = 5

GRADIENT_TO_USER_INPUT = 6

LOSS_OUTPUT = 2

PARAMETER_MUTATION = 4

TOKEN = 8

USER_INPUT_MUTATION = 7

USER_OUTPUT = 1

class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])[source]#

arg: Union[TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument, ConstantArgument, CustomObjArgument, TokenArgument]

kind: OutputKind

target: Optional[str]

class torch.export.graph_signature.SymBoolArgument(name: str)

name: str

class torch.export.graph_signature.SymFloatArgument(name: str)

name: str

class torch.export.graph_signature.SymIntArgument(name: str)

name: str

class torch.export.graph_signature.TensorArgument(name: str)

name: str

class torch.export.graph_signature.TokenArgument(name: str)

name: str