torch.jit.trace — PyTorch 2.7 documentation

torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)

Trace a function and return an executable or ScriptFunction that will be optimized using just-in-time compilation.

Tracing is ideal for code that operates only on Tensors and lists, dictionaries, and tuples of Tensors.
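As a minimal sketch of what "tensors and containers of tensors" means in practice: a tuple of tensors can be passed in and returned under the default settings, while a dict output has to be allowed explicitly with strict=False. The function names below (sum_and_product, as_dict) are hypothetical and used only for illustration.

```python
import torch

# Inputs and outputs built from tuples of tensors trace with the defaults.
def sum_and_product(pair):
    a, b = pair  # `pair` is a tuple of two tensors
    return a + b, a * b  # a tuple of tensors is a valid traced output

pair = (torch.rand(3), torch.rand(3))
traced_pair = torch.jit.trace(sum_and_product, (pair,))

# A dict output is rejected in strict mode; strict=False lets the tracer
# record it, provided the dict's keys never depend on the inputs.
def as_dict(x):
    return {"double": x * 2, "square": x * x}

traced_dict = torch.jit.trace(as_dict, (torch.rand(3),), strict=False)

x = torch.rand(3)
out = traced_dict(x)
print(sorted(out.keys()))  # ['double', 'square']
```

The dict structure is frozen at trace time, which is why strict mode refuses it by default.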

Using torch.jit.trace and torch.jit.trace_module, you can turn an existing module or Python function into a TorchScript ScriptFunction or ScriptModule. You must provide example inputs, and we run the function, recording the operations performed on all the tensors.
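To see what "recording the operations" produces, the traced result exposes the recorded TorchScript graph and a Python-like rendering of it. A small sketch, with scaled_relu as a hypothetical example function:

```python
import torch

def scaled_relu(x):
    return torch.relu(x) * 3

traced = torch.jit.trace(scaled_relu, (torch.randn(4),))

# The recorded operations live in a TorchScript IR graph...
print(traced.graph)
# ...and can also be viewed as Python-like TorchScript code.
print(traced.code)
```

Only the tensor operations that actually ran (here aten::relu and aten::mul) appear in the graph.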

When a module is traced, the resulting ScriptModule also contains all the parameters of the original module.
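A quick way to confirm that parameters carry over is to compare named_parameters() before and after tracing. This sketch uses a small nn.Sequential chosen only for illustration:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 2), nn.ReLU())
traced = torch.jit.trace(model, torch.rand(1, 4))

orig_params = dict(model.named_parameters())
traced_params = dict(traced.named_parameters())

# Same parameter names, same values
print(sorted(traced_params))  # ['0.bias', '0.weight']
```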

Warning

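Tracing only correctly records functions and modules which are not data dependent (e.g., do not have conditionals on data in tensors) and do not have any untracked external dependencies (e.g., perform input/output or access global variables). Tracing only records operations done when the given function is run on the given tensors. Therefore, the returned ScriptModule will always run the same traced graph on any input. This has some important implications when your module is expected to run different sets of operations, depending on the input and/or the module state. For example:

Tracing will not record any control flow like if-statements or loops. When this control flow is constant across your module, this is fine, and the tracer often inlines the control-flow decisions. But sometimes the control flow is actually part of the model itself. For instance, a recurrent network is a loop over the (possibly dynamic) length of an input sequence.

In the returned ScriptModule, operations that behave differently in training and eval modes will always behave as if the module were in the mode it was in during tracing, no matter which mode the ScriptModule is in.

In cases like these, tracing would not be appropriate and scripting (torch.jit.script) is a better choice. If you trace such models, you may silently get incorrect results on subsequent invocations of the model. The tracer will try to emit warnings when doing something that may cause an incorrect trace to be produced.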
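The data-dependence pitfall can be reproduced in a few lines. In this sketch, sign_dependent is a hypothetical function that branches on tensor data; tracing it with a positive example records only the x * 2 branch, so a negative input later gives a different answer from eager Python. Tracing typically emits a TracerWarning at the branch; the sketch silences it only to keep the output short.

```python
import warnings
import torch

def sign_dependent(x):
    # Python-level branch on tensor *data*: the tracer sees only one path.
    if x.sum() > 0:
        return x * 2
    return x * 3

with warnings.catch_warnings():
    # Silence the tracer's warning about converting a tensor to a bool.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(sign_dependent, (torch.ones(3),))

neg = -torch.ones(3)
print(sign_dependent(neg))  # tensor([-3., -3., -3.])
print(traced(neg))          # tensor([-2., -2., -2.]), the recorded branch
```

The trace check does not catch this, because it re-runs only the same positive example input.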

Parameters

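func (callable or torch.nn.Module) – A Python function or torch.nn.Module that will be run with example_inputs. func arguments and return values must be tensors or (possibly nested) tuples that contain tensors. When a module is passed to torch.jit.trace, only the forward method is run and traced (to trace other methods, see torch.jit.trace_module).

example_inputs (tuple or torch.Tensor or None, optional) – A tuple of example inputs that will be passed to the function while tracing. Default: None. Either this argument or example_kwarg_inputs should be specified. The resulting trace can be run with inputs of different types and shapes, assuming the traced operations support those types and shapes. example_inputs may also be a single Tensor, in which case it is automatically wrapped in a tuple.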

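Keyword Arguments

check_trace (bool, optional) – Check whether the same inputs run through the traced code produce the same outputs. Default: True. You might want to disable this if, for example, your network contains non-deterministic ops or if you are sure that the network is correct despite a checker failure.

check_inputs (list of tuples, optional) – A list of tuples of input arguments used to check the trace against what is expected. Each tuple is equivalent to a set of input arguments that would be specified in example_inputs. For best results, pass a set of checking inputs representative of the shapes and types of inputs you expect the network to see. If not specified, the original example_inputs are used for checking.

check_tolerance (float, optional) – Floating-point comparison tolerance used by the checker. This can relax the checker's strictness when results diverge numerically for a known reason, such as operator fusion.

strict (bool, optional) – Whether to run the tracer in strict mode (default: True). Only turn this off when you want the tracer to record your mutable container types (currently list/dict) and you are sure that the container you are using is a constant structure that is not used in control-flow (if, for) conditions.

example_kwarg_inputs (dict, optional) – A pack of keyword arguments of example inputs that will be passed to the function while tracing. Default: None. Either this argument or example_inputs should be specified. The dict is unpacked by the argument names of the traced function; if its keys do not match those names, a runtime exception is raised.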
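The checker-related keyword arguments from the signature (check_inputs, check_tolerance) can be exercised as in the following sketch. normalize is a hypothetical shape-agnostic function; the extra check inputs use different shapes so the checker validates the trace beyond the single example.

```python
import torch

def normalize(x):
    # Shape-agnostic: works for any tensor with more than one element
    return (x - x.mean()) / (x.std() + 1e-6)

traced = torch.jit.trace(
    normalize,
    (torch.rand(8),),
    # Extra argument tuples the checker re-runs after tracing; each tuple
    # has the same layout as example_inputs.
    check_inputs=[(torch.rand(16),), (torch.rand(4, 4),)],
    # Same value as the default; raise it if outputs diverge for a known reason.
    check_tolerance=1e-5,
)

z = torch.rand(10)
print(torch.allclose(traced(z), normalize(z)))  # True
```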

Returns

If func is an nn.Module or the forward method of an nn.Module, trace returns a ScriptModule object with a single forward method containing the traced code. The returned ScriptModule will have the same set of sub-modules and parameters as the original nn.Module. If func is a standalone function, trace returns a ScriptFunction.
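The two return types can be told apart with isinstance checks against torch.jit.ScriptFunction and torch.jit.ScriptModule. In this sketch, add_one and the small nn.Sequential are illustrative choices; the traced module keeps the original sub-modules under their original names.

```python
import torch
import torch.nn as nn

def add_one(x):
    return x + 1

net = nn.Sequential(nn.Linear(2, 2), nn.Tanh())
example = torch.rand(1, 2)

traced_fn = torch.jit.trace(add_one, (example,))
traced_net = torch.jit.trace(net, example)

print(isinstance(traced_fn, torch.jit.ScriptFunction))    # True
print(isinstance(traced_net, torch.jit.ScriptModule))     # True
print([name for name, _ in traced_net.named_children()])  # ['0', '1']
```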

Example (tracing a function):

import torch

def foo(x, y):
    return 2 * x + y

# Run foo with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# traced_foo can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment
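Saving and reloading can be sketched with torch.jit.save and torch.jit.load. This sketch traces a small nn.Linear rather than foo, because torch.jit.save is documented for ScriptModule objects, and it uses an in-memory io.BytesIO buffer in place of a file path. In a Python-free C++ program, LibTorch's torch::jit::load reads the same serialized archive.

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(3, 1)
traced = torch.jit.trace(model, torch.rand(2, 3))

# Serialize to an in-memory buffer (a file path works the same way)
buffer = io.BytesIO()
torch.jit.save(traced, buffer)

# Rewind and load it back as a ScriptModule
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.rand(5, 3)
print(torch.allclose(loaded(x), model(x)))  # True
```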

Example (tracing an existing module):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct ScriptModule with
# a single forward method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces forward) and construct a
# ScriptModule with a single forward method
module = torch.jit.trace(n, example_forward_input)
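A traced module is not tied to the shape of its example input when the recorded operations are shape-agnostic. Net above wraps a single nn.Conv2d(1, 1, 3), so this sketch traces a bare convolution with the same 1x1x3x3 example and then runs it on a larger batch of bigger images:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, 3)
traced = torch.jit.trace(conv, torch.rand(1, 1, 3, 3))

# Different batch size and spatial size than the example input
bigger = torch.rand(4, 1, 8, 8)
out = traced(bigger)
print(out.shape)  # torch.Size([4, 1, 6, 6])
```

A 3x3 kernel with no padding shrinks each 8x8 image to 6x6, exactly as in eager mode.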