torch.hub — PyTorch 2.7 documentation

PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.

Publishing models

PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights) to a GitHub repository by adding a simple hubconf.py file.

hubconf.py can have multiple entrypoints. Each entrypoint is defined as a Python function (for example, a pre-trained model you want to publish).

def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...

How to implement an entrypoint?

Here is a code snippet that specifies an entrypoint for the resnet18 model, expanding the implementation in pytorch/vision/hubconf.py. In most cases, importing the right function in hubconf.py is sufficient; the expanded version is shown here only as an example of how it works. You can see the full script in the pytorch/vision repo.

dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18

# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    # This docstring shows up in hub.help()
    """Resnet18 model
    pretrained (bool): load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model

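Alternatively, an entrypoint can load the pretrained weights itself instead of passing pretrained to the underlying constructor. The logic goes inside the entrypoint definition, after model is created; the two options below are alternatives, so use one of them:

# At the top of hubconf.py
import os
import torch

# Inside the entrypoint, after `model` is created
if pretrained:
    # Option 1: checkpoint saved in the local GitHub repo, e.g. weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, 'weights/save.pth')
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)

    # Option 2 (instead of option 1): checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))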
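The sketch below wires these pieces together without any network access: it writes a tiny hubconf.py into a temporary directory, saves a checkpoint under weights/save.pth next to it, and loads the entrypoint with torch.hub.load(..., source="local"). The entrypoint name tiny_linear, the helper _build, and the 4-to-2 nn.Linear stand-in model are hypothetical choices for illustration; a real repo would return its own model.

```python
import os
import tempfile
import textwrap

import torch

# Hypothetical hubconf.py contents: one public entrypoint and one helper.
HUBCONF = textwrap.dedent('''
    import os

    import torch

    dependencies = ["torch"]


    def _build():
        # Leading underscore: a helper, not an entrypoint.
        return torch.nn.Linear(4, 2)


    def tiny_linear(pretrained=False, **kwargs):
        """Tiny linear model (hypothetical example).

        pretrained (bool): if True, load weights from weights/save.pth
        """
        model = _build()
        if pretrained:
            # Resolve the checkpoint relative to this hubconf.py file.
            dirname = os.path.dirname(__file__)
            checkpoint = os.path.join(dirname, "weights", "save.pth")
            model.load_state_dict(torch.load(checkpoint))
        return model
''')

repo_dir = tempfile.mkdtemp()
with open(os.path.join(repo_dir, "hubconf.py"), "w") as f:
    f.write(HUBCONF)

# Save a reference state dict where the entrypoint expects to find it.
reference = torch.nn.Linear(4, 2)
os.makedirs(os.path.join(repo_dir, "weights"))
torch.save(reference.state_dict(), os.path.join(repo_dir, "weights", "save.pth"))

# source="local" treats repo_or_dir as a filesystem path; nothing is downloaded.
model = torch.hub.load(repo_dir, "tiny_linear", source="local", pretrained=True)
```

Because the checkpoint path is built from __file__, the lookup keeps working wherever Hub places the repo on disk.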

Important Notice

Loading models from Hub

PyTorch Hub provides convenient APIs to explore all available models in Hub through torch.hub.list(), show docstrings and examples through torch.hub.help(), and load the pre-trained models using torch.hub.load().

torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)

List all callable entrypoints available in the repo specified by github.

Parameters

Returns

The available callable entrypoints

Return type

list

Example

entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
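torch.hub.list() has to fetch the repo first, but the rule it applies to hubconf.py is simple: every callable whose name does not start with an underscore counts as an entrypoint. The stdlib-only sketch below mirrors that filter on a local hubconf.py, which can be handy for checking what a repo will expose before publishing. list_local_entrypoints and the hubconf contents are hypothetical; this is not part of torch.hub.

```python
import importlib.util
import os
import tempfile
import textwrap


def list_local_entrypoints(repo_dir):
    """Return public callables defined in repo_dir/hubconf.py (hypothetical helper)."""
    path = os.path.join(repo_dir, "hubconf.py")
    spec = importlib.util.spec_from_file_location("hubconf", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs hubconf.py top to bottom
    # Keep names that are callable and have no leading underscore.
    return [
        name
        for name in dir(module)
        if callable(getattr(module, name)) and not name.startswith("_")
    ]


repo_dir = tempfile.mkdtemp()
with open(os.path.join(repo_dir, "hubconf.py"), "w") as f:
    f.write(textwrap.dedent('''
        from os.path import join as _join  # aliased with "_" so it is not listed

        dependencies = []


        def _helper():
            return "hidden"


        def small_model(**kwargs):
            """Public entrypoint."""
            return {"kind": "small", **kwargs}


        def large_model(**kwargs):
            """Another public entrypoint."""
            return {"kind": "large", **kwargs}
    '''))

print(list_local_entrypoints(repo_dir))  # ['large_model', 'small_model']
```

An imported function that should stay hidden needs an underscore alias, which is why the resnet18 example imports torchvision's function as _resnet18.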

torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)

Show the docstring of entrypoint model.

Parameters

Example

print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
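Since torch.hub.help() shows the entrypoint's docstring, the docstring in hubconf.py doubles as user-facing help. The sketch below performs the same kind of lookup on an in-memory stand-in for hubconf.py; help_for, resnet_tiny, and the error message are hypothetical and only approximate what Hub does.

```python
import types

HUBCONF_SOURCE = '''
def resnet_tiny(pretrained=False, **kwargs):
    """Tiny ResNet (hypothetical).

    pretrained (bool): load pretrained weights into the model
    """
    return {"pretrained": pretrained, **kwargs}
'''

# Build a module object from source, standing in for an imported hubconf.py.
hubconf = types.ModuleType("hubconf")
exec(HUBCONF_SOURCE, hubconf.__dict__)


def help_for(module, name):
    """Return the docstring of entrypoint `name` (hypothetical helper)."""
    entry = getattr(module, name, None)
    if entry is None or not callable(entry):
        raise RuntimeError(f"Cannot find callable {name} in hubconf")
    return entry.__doc__


print(help_for(hubconf, "resnet_tiny"))
```

Documenting the accepted keyword arguments in the docstring is what makes this output useful to someone who has never opened the repo.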

torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)

Load a model from a GitHub repo or a local directory.

Note: Loading a model is the typical use case, but this can also be used for loading other objects such as tokenizers, loss functions, etc.

If source is ‘github’, repo_or_dir is expected to be of the form repo_owner/repo_name[:ref] with an optional ref (a tag or a branch).

If source is ‘local’, repo_or_dir is expected to be a path to a local directory.

Parameters

Returns

The output of the model callable when called with the given *args and **kwargs.

Example

# from a GitHub repo
repo = "pytorch/vision"
model = torch.hub.load(repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1")

# from a local directory (a path requires source="local")
path = "/some/local/path/pytorch/vision"
model = torch.hub.load(path, "resnet50", source="local", weights="ResNet50_Weights.DEFAULT")
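For source='github', the repo_or_dir string packs up to three pieces of information in the form repo_owner/repo_name[:ref]. The sketch below splits such a string; parse_repo and RepoSpec are hypothetical names, and returning None for an omitted ref (leaving the choice of default branch to Hub) is an assumption of this sketch rather than a documented return value.

```python
from typing import NamedTuple, Optional


class RepoSpec(NamedTuple):
    owner: str
    name: str
    ref: Optional[str]  # tag or branch; None when omitted


def parse_repo(repo: str) -> RepoSpec:
    """Split 'repo_owner/repo_name[:ref]' (hypothetical helper)."""
    if ":" in repo:
        # Split on the first ":" only, so refs like "release/1.0" survive.
        path, ref = repo.split(":", 1)
    else:
        path, ref = repo, None
    parts = path.split("/")
    if len(parts) != 2 or not all(parts):
        raise ValueError(f"expected repo_owner/repo_name[:ref], got {repo!r}")
    owner, name = parts
    return RepoSpec(owner, name, ref)


print(parse_repo("pytorch/vision"))
print(parse_repo("pytorch/vision:v0.10.0"))
```

A filesystem path such as /some/local/path/pytorch/vision does not fit this form, which is why local directories have to be loaded with source='local'.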

torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)

Download object at the given URL to a local path.

Parameters

Example

torch.hub.download_url_to_file("https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth", "/tmp/temporary_file")
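When hash_prefix is given, the SHA256 hex digest of the downloaded file is expected to start with that prefix, and a mismatch is treated as an error. The stdlib sketch below shows only that integrity check on a local file standing in for a download; verify_hash_prefix is a hypothetical helper, and its error message is an approximation rather than Hub's exact wording.

```python
import hashlib
import os
import tempfile


def verify_hash_prefix(path, hash_prefix, chunk_size=8192):
    """Raise if the SHA256 of `path` does not start with `hash_prefix` (hypothetical helper)."""
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        # Hash in chunks so large checkpoints need not fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha256.update(chunk)
    digest = sha256.hexdigest()
    if not digest.startswith(hash_prefix):
        raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
    return digest


# Stand-in for a freshly downloaded file.
fd, path = tempfile.mkstemp()
with os.fdopen(fd, "wb") as f:
    f.write(b"fake checkpoint bytes")

expected_prefix = hashlib.sha256(b"fake checkpoint bytes").hexdigest()[:8]
verify_hash_prefix(path, expected_prefix)  # passes silently
```

A short prefix (eight hex digits here) is enough to catch truncated or corrupted downloads without spelling out the full 64-character digest.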

torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)

Loads the Torch serialized object at the given URL.

If the downloaded file is a zip file, it will be automatically decompressed.

If the object is already present in model_dir, it’s deserialized and returned. The default value of model_dir is <hub_dir>/checkpoints where hub_dir is the directory returned by get_dir().

Parameters

Return type

dict[str, Any]

Example

state_dict = torch.hub.load_state_dict_from_url("https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth")
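The caching rule above can be pictured as a path computation: the file name comes from the last component of the URL path (unless file_name overrides it) and is joined onto model_dir, which defaults to <hub_dir>/checkpoints. With check_hash=True, the expected hash prefix is read from a name of the form filename-<sha256>.ext. The sketch below uses hypothetical helper names and a made-up hub_dir; it only computes paths and downloads nothing.

```python
import os
import re
from typing import Optional
from urllib.parse import urlparse

# Matches the "-<hex>." part of names like "resnet18-5c106cde.pth".
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")


def cached_checkpoint_path(url: str, hub_dir: str, file_name: Optional[str] = None) -> str:
    """Where a checkpoint downloaded from `url` would be cached (hypothetical helper)."""
    model_dir = os.path.join(hub_dir, "checkpoints")
    filename = file_name if file_name is not None else os.path.basename(urlparse(url).path)
    return os.path.join(model_dir, filename)


def hash_prefix_from_name(filename: str) -> Optional[str]:
    """Extract the SHA256 prefix embedded in the file name, if any (hypothetical helper)."""
    match = HASH_REGEX.search(filename)
    return match.group(1) if match else None


url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
path = cached_checkpoint_path(url, hub_dir="/home/user/.cache/torch/hub")  # made-up hub_dir
print(path)
print(hash_prefix_from_name(os.path.basename(path)))
```

If a file already exists at the computed path, it is deserialized from disk instead of being downloaded again.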

Running a loaded model:

Note that *args and **kwargs in torch.hub.load() are used to instantiate a model. After you have loaded a model, how can you find out what you can do with the model? A suggested workflow is

To help users explore without switching back and forth to the documentation, we strongly recommend that repo owners make function help messages clear and succinct. It’s also helpful to include a minimal working example.
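Standard Python introspection works on whatever an entrypoint returns. The sketch below uses a hypothetical stand-in class instead of a real Hub model so that it runs without torch; with a real model you would pass the loaded object in the same way, although an nn.Module will list many more methods.

```python
import inspect


class StandInModel:
    """Hypothetical stand-in for an object returned by torch.hub.load()."""

    def forward(self, x, scale=1.0):
        """Scale the input."""
        return x * scale

    def reset(self):
        """Reset internal state."""

    def _private(self):
        pass


model = StandInModel()

# List public methods, then inspect how one of them is called.
public = [n for n in dir(model) if not n.startswith("_") and callable(getattr(model, n))]
print(public)                            # ['forward', 'reset']
print(inspect.signature(model.forward))  # (x, scale=1.0)
print(inspect.getdoc(model.forward))     # Scale the input.
```

help(model.forward) prints the same docstring together with the signature, which is why clear docstrings in the repo pay off here.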

Where are my downloaded models saved?

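The locations are used in the following order: a directory set by calling torch.hub.set_dir(); $TORCH_HOME/hub, if the environment variable TORCH_HOME is set; $XDG_CACHE_HOME/torch/hub, if the environment variable XDG_CACHE_HOME is set; otherwise ~/.cache/torch/hub.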

torch.hub.get_dir()

Get the Torch Hub cache directory used for storing downloaded models & weights.

If set_dir() is not called, the default path is $TORCH_HOME/hub, where the environment variable $TORCH_HOME defaults to $XDG_CACHE_HOME/torch. $XDG_CACHE_HOME follows the X Desktop Group (XDG) Base Directory specification for the Linux filesystem layout, with a default value of ~/.cache if the environment variable is not set.

Return type

str

torch.hub.set_dir(d)

Optionally set the Torch Hub directory used to save downloaded models & weights.

Parameters

d (str) – path to a local folder to save downloaded models & weights.
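The default-path rules in get_dir() and set_dir() above can be written out as a small function. This sketch mirrors those rules (an explicit set_dir() value first, then $TORCH_HOME/hub, then $XDG_CACHE_HOME/torch/hub, then ~/.cache/torch/hub); resolve_hub_dir and its explicit env and set_dir_value parameters are hypothetical, chosen so the logic can be exercised without touching real environment variables.

```python
import os
from typing import Mapping, Optional


def resolve_hub_dir(env: Mapping[str, str], set_dir_value: Optional[str] = None) -> str:
    """Mirror the documented Hub directory precedence (hypothetical helper)."""
    if set_dir_value is not None:
        # An explicit torch.hub.set_dir(d) wins over environment variables.
        return os.path.expanduser(set_dir_value)
    # $XDG_CACHE_HOME defaults to ~/.cache; $TORCH_HOME defaults to $XDG_CACHE_HOME/torch.
    xdg_cache = env.get("XDG_CACHE_HOME", os.path.join("~", ".cache"))
    torch_home = env.get("TORCH_HOME", os.path.join(xdg_cache, "torch"))
    return os.path.join(os.path.expanduser(torch_home), "hub")


print(resolve_hub_dir(os.environ))
```

On a real system torch.hub.get_dir() remains the authoritative answer; the sketch only spells out the precedence.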

Caching logic

By default, Hub does not clean up files after loading them. If the files already exist in the directory returned by get_dir(), Hub uses that cache.

Users can force a reload by calling hub.load(..., force_reload=True). This deletes the existing GitHub folder and downloaded weights and starts a fresh download. This is useful when updates are published to the same branch, so users can keep up with the latest release.
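Each GitHub repo is cached in its own folder under the Hub directory, and force_reload bypasses an existing folder. The sketch below mirrors that decision with hypothetical helper names and a fake fetch callback standing in for the real download. The folder naming <owner>_<name>_<ref> (with any "/" in the ref replaced by "_") reflects how current PyTorch releases lay out the cache, but treat it as an implementation detail rather than a stable API.

```python
import os
import shutil
import tempfile
from typing import Callable


def repo_cache_dir(hub_dir: str, owner: str, name: str, ref: str) -> str:
    """Folder used for one owner/name/ref combination (hypothetical helper)."""
    normalized_ref = ref.replace("/", "_")
    return os.path.join(hub_dir, "_".join([owner, name, normalized_ref]))


def get_cache_or_reload(repo_dir: str, fetch: Callable[[str], None], force_reload: bool = False) -> str:
    """Reuse repo_dir if present, unless force_reload asks for a fresh copy."""
    if force_reload or not os.path.exists(repo_dir):
        shutil.rmtree(repo_dir, ignore_errors=True)  # drop any stale copy
        fetch(repo_dir)                              # stand-in for the real download
    return repo_dir


calls = []


def fake_fetch(target: str) -> None:
    os.makedirs(target)
    calls.append(target)


hub_dir = tempfile.mkdtemp()
d = repo_cache_dir(hub_dir, "pytorch", "vision", "release/0.10")
get_cache_or_reload(d, fake_fetch)                     # first call: fetches
get_cache_or_reload(d, fake_fetch)                     # cached: no fetch
get_cache_or_reload(d, fake_fetch, force_reload=True)  # forced: fetches again
print(os.path.basename(d), len(calls))  # pytorch_vision_release_0.10 2
```

Different refs of the same repo end up in separate folders on disk.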

Known limitations:

Torch Hub works by importing the package as if it were installed. Importing in Python has some side effects. For example, new items appear in the Python caches sys.modules and sys.path_importer_cache, which is normal Python behavior. This also means that you may get import errors when importing different models from different repos if the repos have the same sub-package names (typically, a model subpackage). A workaround for these kinds of import errors is to remove the offending sub-package from the sys.modules dict; more details can be found in this GitHub issue.
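The workaround of removing the offending sub-package from sys.modules can be written as a small helper. purge_modules is hypothetical, and the sketch registers fake modules named model and model.layers so it runs without any Hub repo; in practice you would call it between two torch.hub.load() calls whose repos both ship a model sub-package.

```python
import sys
import types


def purge_modules(package: str) -> list:
    """Remove `package` and its submodules from sys.modules (hypothetical helper)."""
    doomed = [name for name in sys.modules if name == package or name.startswith(package + ".")]
    for name in doomed:
        del sys.modules[name]
    return sorted(doomed)


# Simulate a "model" sub-package left behind by a previously loaded repo.
sys.modules["model"] = types.ModuleType("model")
sys.modules["model.layers"] = types.ModuleType("model.layers")
sys.modules["modelling"] = types.ModuleType("modelling")  # similar prefix, unrelated

print(purge_modules("model"))  # ['model', 'model.layers']
```

Matching on package + "." rather than a bare prefix keeps unrelated modules such as modelling in place.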

Another known limitation worth mentioning: users CANNOT load two different branches of the same repo in the same Python process. It is like installing two packages with the same name in Python, which leads to conflicts, and the import caches described above may hand you modules from the wrong branch if you try. Loading them in separate processes is fine.