torch.jit.freeze — PyTorch 2.7 documentation

torch.jit.freeze(mod, preserved_attrs=None, optimize_numerics=True)

Freeze ScriptModule, inline submodules, and attributes as constants.

Freezing a ScriptModule clones it and attempts to inline the cloned module’s submodules, parameters, and attributes as constants in the TorchScript IR Graph. By default, forward is preserved, along with any attributes and methods specified in preserved_attrs. Additionally, any attribute that is modified within a preserved method is preserved.
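As a minimal sketch of the preservation rules above, the snippet below keeps an extra exported method by naming it in preserved_attrs. The class Scaler and its method double are hypothetical names chosen for illustration; they are not part of PyTorch.

```python
import torch
import torch.nn as nn


class Scaler(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(3.0))

    def forward(self, x):
        return x * self.scale

    @torch.jit.export  # compile this method even though forward never calls it
    def double(self, x):
        return x * 2


scripted = torch.jit.script(Scaler().eval())

# Name the extra method in preserved_attrs so it is kept alongside forward.
frozen = torch.jit.freeze(scripted, preserved_attrs=["double"])

x = torch.tensor(2.0)
forward_out = frozen(x)        # `scale` has been inlined as a constant
double_out = frozen.double(x)  # still callable because it was preserved
```

Here scale is never mutated, so it should no longer show up in frozen.named_parameters(), while double remains available as a method on the frozen module.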

Freezing currently only accepts ScriptModules that are in eval mode.
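The eval-mode requirement can be sketched as follows, assuming a small hypothetical nn.Sequential model. A freshly constructed nn.Module is in training mode, so the first call to freeze is expected to fail with a RuntimeError, and the same module freezes once eval() has been called.

```python
import torch
import torch.nn as nn

# Hypothetical model; any scripted module in training mode behaves the same way.
model = nn.Sequential(nn.Linear(4, 2), nn.Dropout(p=0.5))
scripted_train = torch.jit.script(model)  # nn.Module defaults to training mode

try:
    torch.jit.freeze(scripted_train)
    rejected = False
except RuntimeError:
    # freeze refuses modules whose `training` flag is True
    rejected = True

# Switching to eval mode first makes the same module freezable.
frozen = torch.jit.freeze(scripted_train.eval())
out = frozen(torch.ones(1, 4))
```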

Freezing applies generic optimization that will speed up your model regardless of machine. To further optimize using server-specific settings, run optimize_for_inference after freezing.
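A rough sketch of the two-step flow, assuming a small hypothetical Conv2d + BatchNorm2d model: freeze first, then pass the frozen module to torch.jit.optimize_for_inference. The layer sizes and the tolerance are arbitrary illustration values.

```python
import torch
import torch.nn as nn

# Hypothetical model: a convolution followed by batch norm.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, bias=True),
    nn.BatchNorm2d(8),
).eval()

x = torch.randn(1, 3, 16, 16)

scripted = torch.jit.script(model)
frozen = torch.jit.freeze(scripted)  # optimize_numerics=True by default

with torch.no_grad():
    expected = model(x)
    frozen_out = frozen(x)

# Optional second step: build- and device-specific optimizations.
optimized = torch.jit.optimize_for_inference(frozen)

with torch.no_grad():
    optimized_out = optimized(x)

# Results should agree with eager mode up to small numerical differences.
frozen_matches = torch.allclose(frozen_out, expected, atol=1e-4)
optimized_matches = torch.allclose(optimized_out, expected, atol=1e-4)
```

Comparing against the eager output with a tolerance, rather than exact equality, is deliberate: with optimize_numerics enabled the optimization passes are not required to preserve numerics bit for bit.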

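Parameters

mod (ScriptModule): the module to be frozen.

preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method. Attributes modified in preserved methods are also preserved.

optimize_numerics (bool): if True, a set of optimization passes is run that does not strictly preserve numerics. Full details of the optimizations can be found at torch.jit.run_frozen_optimizations.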

Returns

Frozen ScriptModule.

Example (Freezing a simple module with a Parameter):

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mm(input)
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3).eval())
frozen_module = torch.jit.freeze(scripted_module)

# parameters have been removed and inlined into the Graph as constants
assert len(list(frozen_module.named_parameters())) == 0

# See the compiled graph as Python code
print(frozen_module.code)

Example (Freezing a module with preserved attributes):

import torch

class MyModule2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.modified_tensor = torch.tensor(10.)
        self.version = 1

    def forward(self, input):
        self.modified_tensor += 1
        return input + self.modified_tensor

scripted_module = torch.jit.script(MyModule2().eval())
frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])

# we've manually preserved version, so it still exists on the frozen module and can be modified
assert frozen_module.version == 1
frozen_module.version = 2

# modified_tensor is detected as being mutated in the forward, so freezing preserves
# it to retain model semantics
assert frozen_module(torch.tensor(1)) == torch.tensor(12)

# now that we've run it once, the next result will be incremented by one
assert frozen_module(torch.tensor(1)) == torch.tensor(13)

Note

Submodule attributes can also be preserved: frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["submodule.version"])
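A hedged sketch of the dotted-path form, using hypothetical Child and Parent classes. It assumes the preserved attribute remains readable and writable through the frozen module's submodule, and that the change is visible on the next call.

```python
import torch
import torch.nn as nn


class Child(nn.Module):  # hypothetical submodule
    def __init__(self):
        super().__init__()
        self.version = 1   # preserved below via "submodule.version"
        self.offset = 10   # not preserved, so eligible for inlining

    def forward(self, x):
        return x + self.version + self.offset


class Parent(nn.Module):  # hypothetical parent module
    def __init__(self):
        super().__init__()
        self.submodule = Child()

    def forward(self, x):
        return self.submodule(x)


scripted = torch.jit.script(Parent().eval())
frozen = torch.jit.freeze(scripted, preserved_attrs=["submodule.version"])

before = frozen(torch.tensor(0))   # 0 + version (1) + offset (10)
version_before = frozen.submodule.version

frozen.submodule.version = 2       # the preserved attribute stays mutable
after = frozen(torch.tensor(0))    # 0 + version (2) + offset (10)
```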

Note

If you’re not sure why an attribute is not being inlined as a constant, you can run dump_alias_db on frozen_module.forward.graph to see whether freezing has detected that the attribute is being modified.
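The check described in this note might look like the following sketch, where Counter is a hypothetical module whose forward mutates a tensor attribute in place. dump_alias_db prints its report to standard output, and the string check on the graph is only an informal way to confirm that the attribute was kept rather than inlined.

```python
import torch
import torch.nn as nn


class Counter(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.count = torch.tensor(0.0)

    def forward(self, x):
        self.count += 1  # in-place mutation inside a preserved method
        return x + self.count


frozen = torch.jit.freeze(torch.jit.script(Counter().eval()))

graph = frozen.forward.graph
graph.dump_alias_db()  # prints the alias analysis for this graph to stdout

# A mutated attribute stays a live attribute, so the graph still reads it
# through a prim::GetAttr node instead of a prim::Constant.
still_attribute = 'GetAttr[name="count"]' in str(graph)
```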

Note

Because freezing turns weights into constants and removes the module hierarchy, to and other nn.Module methods that manipulate device or dtype no longer work. As a workaround, you can remap devices by specifying map_location in torch.jit.load; however, device-specific logic may already have been baked into the model.
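The map_location workaround can be sketched as a save and load round trip. This example uses an in-memory buffer and the "cpu" device so that it runs anywhere; a file path and a device string such as "cuda:0" would follow the same pattern. The nn.Linear model is a placeholder.

```python
import io

import torch
import torch.nn as nn

# Placeholder model; any frozen ScriptModule can be saved the same way.
frozen = torch.jit.freeze(torch.jit.script(nn.Linear(4, 2).eval()))

# Serialize the frozen module, then choose the target device at load time
# instead of calling .to() on the frozen module.
buffer = io.BytesIO()
torch.jit.save(frozen, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer, map_location="cpu")

x = torch.ones(1, 4)
same_output = torch.allclose(frozen(x), reloaded(x))
```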