torch.jit.trace_module (PyTorch 2.7 documentation)

torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)

Trace a module and return an executable ScriptModule that will be optimized using just-in-time compilation.

When a module is passed to torch.jit.trace, only the forward method is run and traced. With trace_module, you can pass a dictionary that maps method names to example inputs, and each of those methods is traced (see the inputs argument below).
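A minimal sketch of the inputs dictionary, using a hypothetical module Mixer with a hypothetical two-argument method blend. Each key names a method, and each value is that method's example input: a single tensor for a one-argument method, or a tuple of tensors for a method with several positional arguments.

```python
import torch
import torch.nn as nn


class Mixer(nn.Module):
    """Hypothetical module: a one-argument forward and a two-argument `blend`."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.rand(4))

    def forward(self, x):
        return x * self.weight

    def blend(self, a, b):
        # Two positional arguments, so its example input must be a tuple (a, b).
        return a * self.weight + b * (1 - self.weight)


m = Mixer()
x = torch.rand(2, 4)
a, b = torch.rand(2, 4), torch.rand(2, 4)

# torch.jit.trace(m, x) would trace only `forward`; trace_module traces each
# method named as a key, calling it with the corresponding example input.
traced = torch.jit.trace_module(m, {"forward": x, "blend": (a, b)})

# Both traced methods reproduce the eager results on the example inputs.
print(torch.allclose(traced(x), m(x)))
print(torch.allclose(traced.blend(a, b), m.blend(a, b)))
```

The module avoids calling the same sub-module from both methods, which keeps the sketch close to the documented example's pattern.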

See torch.jit.trace for more information on tracing.

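Parameters

mod (torch.nn.Module): A torch.nn.Module containing the methods whose names are given as keys in inputs. The given methods are compiled as part of a single ScriptModule.

inputs (dict): Example inputs indexed by method name in mod. While tracing, each value is passed to the method whose name matches its key, for example {"forward": example_forward_input, "method2": example_method2_input}. Each value is either a single tensor or a tuple of positional arguments.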

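Keyword Arguments

check_trace (bool, optional): Check whether the same inputs, run through the traced code, produce the same outputs. Default: True. You might want to disable this if, for example, the module contains non-deterministic ops, or if you are sure the module is correct despite a checker failure.

check_inputs (list of dicts, optional): A list of dicts of input arguments used to check the trace against what is expected. Each dict is equivalent to a set of input arguments that would be specified in inputs. For best results, pass checking inputs representative of the shapes and types the module is expected to see. If not specified, the original inputs are used for checking.

check_tolerance (float, optional): Floating-point comparison tolerance used by the checker. This can relax the checker's strictness when results diverge numerically for a known reason, such as operator fusion. Default: 1e-05.

example_inputs_is_kwarg (bool, optional): Whether each example input is a dict of keyword arguments rather than positional arguments. Default: False.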
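A sketch of the trace checker, assuming a hypothetical single-convolution module ConvNet. check_inputs takes a list of dicts shaped like inputs (method name to example input); each dict is used to re-run the trace and compare results within check_tolerance. Using spatial sizes different from the example input exercises the trace on shapes it was not recorded with.

```python
import torch
import torch.nn as nn


class ConvNet(nn.Module):
    """Hypothetical module used only to illustrate the trace checker."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)


net = ConvNet()

traced = torch.jit.trace_module(
    net,
    {"forward": torch.rand(1, 1, 3, 3)},
    check_trace=True,  # the default; set False to skip checking, e.g. for nondeterministic ops
    # Each dict mirrors `inputs`: method name -> example input for that method.
    check_inputs=[{"forward": torch.rand(1, 1, 5, 5)}, {"forward": torch.rand(2, 1, 8, 8)}],
    check_tolerance=1e-4,  # relaxed from the 1e-05 default
)

# A 3x3 convolution without padding shrinks each spatial dimension by 2.
out = traced(torch.rand(2, 1, 8, 8))
print(out.shape)
```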

Returns

A ScriptModule object containing one traced method for each method named in inputs. The returned ScriptModule has the same set of sub-modules and parameters as mod.
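To illustrate the return value, a sketch with a hypothetical module AutoEncoder and a hypothetical extra method scaled_bias: the traced module keeps the sub-module and parameter names of the original and exposes the extra traced method alongside forward.

```python
import torch
import torch.nn as nn


class AutoEncoder(nn.Module):
    """Hypothetical module with two sub-modules and an extra method."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.decoder = nn.Linear(8, 4)

    def forward(self, x):
        return self.decoder(self.encoder(x))

    def scaled_bias(self, scale):
        # Reads a parameter directly rather than calling a sub-module.
        return scale * self.decoder.bias


model = AutoEncoder()
traced = torch.jit.trace_module(
    model, {"forward": torch.rand(3, 4), "scaled_bias": torch.rand(4)}
)

# Sub-module and parameter names match those of `model`.
traced_params = {name for name, _ in traced.named_parameters()}
traced_children = {name for name, _ in traced.named_children()}
print(sorted(traced_params))
print(sorted(traced_children))
```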

Example (tracing a module with multiple methods):

    import torch
    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = nn.Conv2d(1, 1, 3)

        def forward(self, x):
            return self.conv(x)

        def weighted_kernel_sum(self, weight):
            return weight * self.conv.weight

    n = Net()
    example_weight = torch.rand(1, 1, 3, 3)
    example_forward_input = torch.rand(1, 1, 3, 3)

    # Trace a specific method and construct `ScriptModule` with
    # a single `forward` method
    module = torch.jit.trace(n.forward, example_forward_input)

    # Trace a module (implicitly traces `forward`) and construct a
    # `ScriptModule` with a single `forward` method
    module = torch.jit.trace(n, example_forward_input)

    # Trace specific methods on a module (specified in `inputs`) and construct
    # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
    inputs = {"forward": example_forward_input, "weighted_kernel_sum": example_weight}
    module = torch.jit.trace_module(n, inputs)
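Each method traced this way becomes a separate compiled method on the returned module. A minimal sketch, reusing the example's Net and assuming each compiled method exposes a .graph attribute, compares the recorded graphs: the multiply traced from weighted_kernel_sum appears only in that method's graph, not in the graph for forward.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
inputs = {
    "forward": torch.rand(1, 1, 3, 3),
    "weighted_kernel_sum": torch.rand(1, 1, 3, 3),
}
module = torch.jit.trace_module(n, inputs)

# Each traced method carries its own TorchScript graph of the recorded ops.
forward_graph = str(module.forward.graph)
kernel_sum_graph = str(module.weighted_kernel_sum.graph)
print(kernel_sum_graph)
```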