torch.jit.load (PyTorch 2.7 documentation)

torch.jit.load(f, map_location=None, _extra_files=None, _restore_shapes=False)

Load a ScriptModule or ScriptFunction previously saved with torch.jit.save.
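Since a saved ScriptFunction can be loaded as well as a ScriptModule, a short round trip may help. This is a minimal sketch. It assumes a writable temporary directory, and the function name `double` is hypothetical. As far as we can tell, the loaded object is a ScriptModule whose `forward` method is the saved function, so you call it like a module.

```python
import os
import tempfile

import torch


@torch.jit.script
def double(x: torch.Tensor) -> torch.Tensor:
    # A standalone TorchScript function (a torch.jit.ScriptFunction).
    return x * 2


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "double.pt")
    torch.jit.save(double, path)   # save the ScriptFunction to disk
    loaded = torch.jit.load(path)  # load it back

    # The loaded object is called like a module; the saved function
    # is exposed as its forward method.
    result = loaded(torch.tensor([1.0, 2.0]))

print(result)  # tensor([2., 4.])
```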

All previously saved modules, regardless of their device, are first loaded onto the CPU and then moved to the devices they were saved from. If this fails (e.g. because the runtime system doesn't have certain devices), an exception is raised. To place the loaded tensors on a different device instead, pass map_location.
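If a saved module might reference a device that the loading machine lacks, passing map_location avoids the exception described above. The sketch below is minimal. To stay runnable anywhere, it saves a module that is already on the CPU and then picks a target device that actually exists. The CUDA-or-CPU choice is just one illustrative policy.

```python
import io

import torch

# Script and serialize a small module (it is created on CPU here).
model = torch.jit.script(torch.nn.Linear(4, 2))
buffer = io.BytesIO()
torch.jit.save(model, buffer)

# Choose a device that exists on this machine instead of relying on the
# device recorded at save time.
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

buffer.seek(0)
loaded = torch.jit.load(buffer, map_location=target)

# Every parameter now lives on the chosen device type.
print({p.device.type for p in loaded.parameters()})
```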

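Parameters

f: a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name.

map_location (str or torch.device): used to dynamically remap storages to an alternative set of devices, e.g. 'cpu' or torch.device('cpu').

_extra_files (dict of filename to content): the extra files named by the keys are loaded, and their contents are stored in the provided dict.

_restore_shapes (bool): whether to retrace the module on load using stored inputs.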

Returns

A ScriptModule object.
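The round trip below shows the returned object and the file-like form of f. It is a minimal sketch. It uses an in-memory io.BytesIO buffer instead of a file on disk and scripts the built-in torch.nn.ReLU purely for illustration.

```python
import io

import torch

scripted = torch.jit.script(torch.nn.ReLU())

# torch.jit.save and torch.jit.load accept file-like objects, so an
# in-memory buffer works as well as a path.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)  # rewind before reading

loaded = torch.jit.load(buffer)
out = loaded(torch.tensor([-1.0, 2.0]))

print(isinstance(loaded, torch.jit.ScriptModule))  # True
print(out)  # tensor([0., 2.])
```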

Warning

It is possible to construct malicious pickle data which will execute arbitrary code during torch.jit.load. Never load data that could have come from an untrusted source, or that could have been tampered with. Only load data you trust.
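One common precaution against tampering is to compare a file's SHA-256 digest with a value recorded when a trusted party produced it, and to load only on a match. This is an assumption about your workflow, not something torch.jit.load does itself. It also does not make content from an untrusted source safe, because a matching digest proves only that the bytes are unchanged. The helper `load_trusted` is hypothetical.

```python
import hashlib
import io

import torch


def load_trusted(data: bytes, expected_sha256: str) -> torch.jit.ScriptModule:
    """Hypothetical helper: deserialize only if the bytes match a digest
    recorded when the file was produced by a trusted party."""
    actual = hashlib.sha256(data).hexdigest()
    if actual != expected_sha256:
        raise ValueError(f"SHA-256 mismatch: got {actual}")
    return torch.jit.load(io.BytesIO(data), map_location="cpu")


# Produce some bytes and record their digest at "publish" time.
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(torch.nn.ReLU()), buffer)
payload = buffer.getvalue()
digest = hashlib.sha256(payload).hexdigest()

module = load_trusted(payload, digest)  # digest matches, so this loads

# Flipping one byte changes the digest, so the helper refuses to load.
tampered = bytes([payload[0] ^ 0xFF]) + payload[1:]
try:
    load_trusted(tampered, digest)
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```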

Example:

```python
import torch
import io

torch.jit.load('scriptmodule.pt')

# Load ScriptModule from io.BytesIO object
with open('scriptmodule.pt', 'rb') as f:
    buffer = io.BytesIO(f.read())

# Load all tensors to the original device
torch.jit.load(buffer)

# Load all tensors onto CPU, using a device
buffer.seek(0)
torch.jit.load(buffer, map_location=torch.device('cpu'))

# Load all tensors onto CPU, using a string
buffer.seek(0)
torch.jit.load(buffer, map_location='cpu')

# Load with extra files.
extra_files = {'foo.txt': ''}  # values will be replaced with data
torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
print(extra_files['foo.txt'])
```
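An end-to-end version of the extra-files pattern is below: first write an extra file with torch.jit.save, then read it back with torch.jit.load. This is a minimal sketch that assumes a writable temporary directory. The file name and content are illustrative. The loaded content may come back as bytes, so the sketch decodes it only when it is bytes.

```python
import os
import tempfile

import torch

model = torch.jit.script(torch.nn.Linear(2, 2))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "scriptmodule.pt")
    # Store an extra file alongside the module at save time.
    torch.jit.save(model, path, _extra_files={"foo.txt": b"bar"})

    # The keys select which extra files to read; the values are placeholders
    # that torch.jit.load overwrites with the stored contents.
    extra_files = {"foo.txt": ""}
    loaded = torch.jit.load(path, _extra_files=extra_files)

data = extra_files["foo.txt"]
text = data.decode("utf-8") if isinstance(data, bytes) else data
print(text)  # bar
```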