Models and pre-trained weights — Torchvision 0.16 documentation

The torchvision.models subpackage contains definitions of models for addressing different tasks, including: image classification, pixelwise semantic segmentation, object detection, instance segmentation, person keypoint detection, video classification, and optical flow.

General information on pre-trained weights

TorchVision offers pre-trained weights for every provided architecture, using PyTorch torch.hub. Instantiating a pre-trained model will download its weights to a cache directory. This directory can be set using the TORCH_HOME environment variable. See torch.hub.load_state_dict_from_url() for details.
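As a minimal sketch of how the cache location interacts with model instantiation (the path below is purely illustrative):

import os
os.environ["TORCH_HOME"] = "/tmp/torch_cache"  # illustrative path; must be set before the first download

from torchvision.models import resnet50, ResNet50_Weights
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)  # weights are downloaded into the cache on first use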

Note

The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the dataset used for training. It is your responsibility to determine whether you have permission to use the models for your use case.

Note

Backward compatibility is guaranteed for loading a serialized state_dict into a model created with an older PyTorch version. In contrast, loading entire saved models or serialized ScriptModules (serialized using older versions of PyTorch) may not preserve the historic behaviour. Refer to the following documentation.
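As a rough sketch of the pattern this note recommends, persist and reload the state_dict rather than pickling the whole model (the file name below is arbitrary):

import torch
from torchvision.models import resnet50, ResNet50_Weights

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
torch.save(model.state_dict(), "resnet50.pth")  # arbitrary file name; only the weights are serialized

new_model = resnet50(weights=None)              # recreate the architecture in the current PyTorch version
new_model.load_state_dict(torch.load("resnet50.pth"))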

Initializing pre-trained models

As of v0.13, TorchVision offers a new Multi-weight support API for loading different weights into the existing model builder methods:

from torchvision.models import resnet50, ResNet50_Weights

Old weights with accuracy 76.130%

resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

New weights with accuracy 80.858%

resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

Best available weights (currently alias for IMAGENET1K_V2)

Note that these weights may change across versions

resnet50(weights=ResNet50_Weights.DEFAULT)

Strings are also supported

resnet50(weights="IMAGENET1K_V2")

No weights - random initialization

resnet50(weights=None)

Migrating to the new API is very straightforward. The following method calls are equivalent across the two APIs:

from torchvision.models import resnet50, ResNet50_Weights

Using pretrained weights:

resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
resnet50(pretrained=True)  # deprecated
resnet50(True)  # deprecated

Using no weights:

resnet50(weights=None)
resnet50()
resnet50(pretrained=False)  # deprecated
resnet50(False)  # deprecated

Note that the pretrained parameter is now deprecated; using it will emit warnings, and it will be removed in v0.15.

Using the pre-trained models

Before using the pre-trained models, one must preprocess the image (resize with the right resolution/interpolation, apply inference transforms, rescale the values, etc.). There is no standard way to do this as it depends on how a given model was trained. It can vary across model families, variants or even weight versions. Using the correct preprocessing method is critical, and failing to do so may lead to decreased accuracy or incorrect outputs.

All the necessary information for the inference transforms of each pre-trained model is provided on its weights documentation. To simplify inference, TorchVision bundles the necessary preprocessing transforms into each model weight. These are accessible via the weight.transforms attribute:

Initialize the Weight Transforms

weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

Apply it to the input image

img_transformed = preprocess(img)

Some models use modules which have different training and evaluation behavior, such as batch normalization. To switch between these modes, use model.train() or model.eval() as appropriate. See train() or eval() for details.

Initialize model

weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)

Set model to eval mode

model.eval()

Listing and retrieving available models

As of v0.14, TorchVision offers a new mechanism which allows listing and retrieving models and weights by their names. Here are a few examples of how to use them:

import torchvision
from torchvision.models import get_model, get_model_weights, get_weight, list_models
from torchvision.models.quantization import MobileNet_V3_Large_QuantizedWeights

List available models

all_models = list_models()
classification_models = list_models(module=torchvision.models)

Initialize models

m1 = get_model("mobilenet_v3_large", weights=None)
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")

Fetch weights

weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT

weights_enum = get_model_weights("quantized_mobilenet_v3_large")
assert weights_enum == MobileNet_V3_Large_QuantizedWeights

weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2

Here are the available public functions to retrieve models and their corresponding weights:

get_model(name, **config) Gets the model name and configuration and returns an instantiated model.
get_model_weights(name) Returns the weights enum class associated to the given model.
get_weight(name) Gets the weights enum value by its full name.
list_models([module, include, exclude]) Returns a list with the names of registered models.
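For instance, the include and exclude parameters accept glob-style patterns, so related model families can be filtered by name; a small sketch (the patterns below are only examples):

from torchvision.models import list_models

resnet_like = list_models(include="*resnet*")                          # names matching the pattern
resnet_like_fp32 = list_models(include="*resnet*", exclude="quantized_*")
print(sorted(resnet_like_fp32))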

Using models from Hub

Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:

import torch

Option 1: passing weights param as string

model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")

Option 2: passing weights param as enum

weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)

You can also retrieve all the available weights of a specific model via PyTorch Hub by doing:

import torch

weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
print([weight for weight in weight_enum])

The only exceptions to the above are the detection models included in torchvision.models.detection. These models require TorchVision to be installed because they depend on custom C++ operators.

Classification

The following classification models are available, with or without pre-trained weights:

Here is an example of how to use the pre-trained image classification models:

from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

Step 1: Initialize model with the best available weights

weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()

Step 2: Initialize the inference transforms

preprocess = weights.transforms()

Step 3: Apply inference preprocessing transforms

batch = preprocess(img).unsqueeze(0)

Step 4: Use the model and print the predicted category

prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")

The classes of the pre-trained model outputs can be found at weights.meta["categories"].
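The metadata itself can be inspected directly; a small sketch:

weights = ResNet50_Weights.DEFAULT
categories = weights.meta["categories"]
print(len(categories))   # 1000 ImageNet-1K class names
print(categories[:5])    # first few entries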

Table of all available classification weights

Accuracies are reported on ImageNet-1K using single crops:

Weight Acc@1 Acc@5 Params GFLOPS Recipe
AlexNet_Weights.IMAGENET1K_V1 56.522 79.066 61.1M 0.71 link
ConvNeXt_Base_Weights.IMAGENET1K_V1 84.062 96.87 88.6M 15.36 link
ConvNeXt_Large_Weights.IMAGENET1K_V1 84.414 96.976 197.8M 34.36 link
ConvNeXt_Small_Weights.IMAGENET1K_V1 83.616 96.65 50.2M 8.68 link
ConvNeXt_Tiny_Weights.IMAGENET1K_V1 82.52 96.146 28.6M 4.46 link
DenseNet121_Weights.IMAGENET1K_V1 74.434 91.972 8.0M 2.83 link
DenseNet161_Weights.IMAGENET1K_V1 77.138 93.56 28.7M 7.73 link
DenseNet169_Weights.IMAGENET1K_V1 75.6 92.806 14.1M 3.36 link
DenseNet201_Weights.IMAGENET1K_V1 76.896 93.37 20.0M 4.29 link
EfficientNet_B0_Weights.IMAGENET1K_V1 77.692 93.532 5.3M 0.39 link
EfficientNet_B1_Weights.IMAGENET1K_V1 78.642 94.186 7.8M 0.69 link
EfficientNet_B1_Weights.IMAGENET1K_V2 79.838 94.934 7.8M 0.69 link
EfficientNet_B2_Weights.IMAGENET1K_V1 80.608 95.31 9.1M 1.09 link
EfficientNet_B3_Weights.IMAGENET1K_V1 82.008 96.054 12.2M 1.83 link
EfficientNet_B4_Weights.IMAGENET1K_V1 83.384 96.594 19.3M 4.39 link
EfficientNet_B5_Weights.IMAGENET1K_V1 83.444 96.628 30.4M 10.27 link
EfficientNet_B6_Weights.IMAGENET1K_V1 84.008 96.916 43.0M 19.07 link
EfficientNet_B7_Weights.IMAGENET1K_V1 84.122 96.908 66.3M 37.75 link
EfficientNet_V2_L_Weights.IMAGENET1K_V1 85.808 97.788 118.5M 56.08 link
EfficientNet_V2_M_Weights.IMAGENET1K_V1 85.112 97.156 54.1M 24.58 link
EfficientNet_V2_S_Weights.IMAGENET1K_V1 84.228 96.878 21.5M 8.37 link
GoogLeNet_Weights.IMAGENET1K_V1 69.778 89.53 6.6M 1.5 link
Inception_V3_Weights.IMAGENET1K_V1 77.294 93.45 27.2M 5.71 link
MNASNet0_5_Weights.IMAGENET1K_V1 67.734 87.49 2.2M 0.1 link
MNASNet0_75_Weights.IMAGENET1K_V1 71.18 90.496 3.2M 0.21 link
MNASNet1_0_Weights.IMAGENET1K_V1 73.456 91.51 4.4M 0.31 link
MNASNet1_3_Weights.IMAGENET1K_V1 76.506 93.522 6.3M 0.53 link
MaxVit_T_Weights.IMAGENET1K_V1 83.7 96.722 30.9M 5.56 link
MobileNet_V2_Weights.IMAGENET1K_V1 71.878 90.286 3.5M 0.3 link
MobileNet_V2_Weights.IMAGENET1K_V2 72.154 90.822 3.5M 0.3 link
MobileNet_V3_Large_Weights.IMAGENET1K_V1 74.042 91.34 5.5M 0.22 link
MobileNet_V3_Large_Weights.IMAGENET1K_V2 75.274 92.566 5.5M 0.22 link
MobileNet_V3_Small_Weights.IMAGENET1K_V1 67.668 87.402 2.5M 0.06 link
RegNet_X_16GF_Weights.IMAGENET1K_V1 80.058 94.944 54.3M 15.94 link
RegNet_X_16GF_Weights.IMAGENET1K_V2 82.716 96.196 54.3M 15.94 link
RegNet_X_1_6GF_Weights.IMAGENET1K_V1 77.04 93.44 9.2M 1.6 link
RegNet_X_1_6GF_Weights.IMAGENET1K_V2 79.668 94.922 9.2M 1.6 link
RegNet_X_32GF_Weights.IMAGENET1K_V1 80.622 95.248 107.8M 31.74 link
RegNet_X_32GF_Weights.IMAGENET1K_V2 83.014 96.288 107.8M 31.74 link
RegNet_X_3_2GF_Weights.IMAGENET1K_V1 78.364 93.992 15.3M 3.18 link
RegNet_X_3_2GF_Weights.IMAGENET1K_V2 81.196 95.43 15.3M 3.18 link
RegNet_X_400MF_Weights.IMAGENET1K_V1 72.834 90.95 5.5M 0.41 link
RegNet_X_400MF_Weights.IMAGENET1K_V2 74.864 92.322 5.5M 0.41 link
RegNet_X_800MF_Weights.IMAGENET1K_V1 75.212 92.348 7.3M 0.8 link
RegNet_X_800MF_Weights.IMAGENET1K_V2 77.522 93.826 7.3M 0.8 link
RegNet_X_8GF_Weights.IMAGENET1K_V1 79.344 94.686 39.6M 8 link
RegNet_X_8GF_Weights.IMAGENET1K_V2 81.682 95.678 39.6M 8 link
RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_E2E_V1 88.228 98.682 644.8M 374.57 link
RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 86.068 97.844 644.8M 127.52 link
RegNet_Y_16GF_Weights.IMAGENET1K_V1 80.424 95.24 83.6M 15.91 link
RegNet_Y_16GF_Weights.IMAGENET1K_V2 82.886 96.328 83.6M 15.91 link
RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1 86.012 98.054 83.6M 46.73 link
RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 83.976 97.244 83.6M 15.91 link
RegNet_Y_1_6GF_Weights.IMAGENET1K_V1 77.95 93.966 11.2M 1.61 link
RegNet_Y_1_6GF_Weights.IMAGENET1K_V2 80.876 95.444 11.2M 1.61 link
RegNet_Y_32GF_Weights.IMAGENET1K_V1 80.878 95.34 145.0M 32.28 link
RegNet_Y_32GF_Weights.IMAGENET1K_V2 83.368 96.498 145.0M 32.28 link
RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1 86.838 98.362 145.0M 94.83 link
RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 84.622 97.48 145.0M 32.28 link
RegNet_Y_3_2GF_Weights.IMAGENET1K_V1 78.948 94.576 19.4M 3.18 link
RegNet_Y_3_2GF_Weights.IMAGENET1K_V2 81.982 95.972 19.4M 3.18 link
RegNet_Y_400MF_Weights.IMAGENET1K_V1 74.046 91.716 4.3M 0.4 link
RegNet_Y_400MF_Weights.IMAGENET1K_V2 75.804 92.742 4.3M 0.4 link
RegNet_Y_800MF_Weights.IMAGENET1K_V1 76.42 93.136 6.4M 0.83 link
RegNet_Y_800MF_Weights.IMAGENET1K_V2 78.828 94.502 6.4M 0.83 link
RegNet_Y_8GF_Weights.IMAGENET1K_V1 80.032 95.048 39.4M 8.47 link
RegNet_Y_8GF_Weights.IMAGENET1K_V2 82.828 96.33 39.4M 8.47 link
ResNeXt101_32X8D_Weights.IMAGENET1K_V1 79.312 94.526 88.8M 16.41 link
ResNeXt101_32X8D_Weights.IMAGENET1K_V2 82.834 96.228 88.8M 16.41 link
ResNeXt101_64X4D_Weights.IMAGENET1K_V1 83.246 96.454 83.5M 15.46 link
ResNeXt50_32X4D_Weights.IMAGENET1K_V1 77.618 93.698 25.0M 4.23 link
ResNeXt50_32X4D_Weights.IMAGENET1K_V2 81.198 95.34 25.0M 4.23 link
ResNet101_Weights.IMAGENET1K_V1 77.374 93.546 44.5M 7.8 link
ResNet101_Weights.IMAGENET1K_V2 81.886 95.78 44.5M 7.8 link
ResNet152_Weights.IMAGENET1K_V1 78.312 94.046 60.2M 11.51 link
ResNet152_Weights.IMAGENET1K_V2 82.284 96.002 60.2M 11.51 link
ResNet18_Weights.IMAGENET1K_V1 69.758 89.078 11.7M 1.81 link
ResNet34_Weights.IMAGENET1K_V1 73.314 91.42 21.8M 3.66 link
ResNet50_Weights.IMAGENET1K_V1 76.13 92.862 25.6M 4.09 link
ResNet50_Weights.IMAGENET1K_V2 80.858 95.434 25.6M 4.09 link
ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1 60.552 81.746 1.4M 0.04 link
ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1 69.362 88.316 2.3M 0.14 link
ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1 72.996 91.086 3.5M 0.3 link
ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1 76.23 93.006 7.4M 0.58 link
SqueezeNet1_0_Weights.IMAGENET1K_V1 58.092 80.42 1.2M 0.82 link
SqueezeNet1_1_Weights.IMAGENET1K_V1 58.178 80.624 1.2M 0.35 link
Swin_B_Weights.IMAGENET1K_V1 83.582 96.64 87.8M 15.43 link
Swin_S_Weights.IMAGENET1K_V1 83.196 96.36 49.6M 8.74 link
Swin_T_Weights.IMAGENET1K_V1 81.474 95.776 28.3M 4.49 link
Swin_V2_B_Weights.IMAGENET1K_V1 84.112 96.864 87.9M 20.32 link
Swin_V2_S_Weights.IMAGENET1K_V1 83.712 96.816 49.7M 11.55 link
Swin_V2_T_Weights.IMAGENET1K_V1 82.072 96.132 28.4M 5.94 link
VGG11_BN_Weights.IMAGENET1K_V1 70.37 89.81 132.9M 7.61 link
VGG11_Weights.IMAGENET1K_V1 69.02 88.628 132.9M 7.61 link
VGG13_BN_Weights.IMAGENET1K_V1 71.586 90.374 133.1M 11.31 link
VGG13_Weights.IMAGENET1K_V1 69.928 89.246 133.0M 11.31 link
VGG16_BN_Weights.IMAGENET1K_V1 73.36 91.516 138.4M 15.47 link
VGG16_Weights.IMAGENET1K_V1 71.592 90.382 138.4M 15.47 link
VGG16_Weights.IMAGENET1K_FEATURES nan nan 138.4M 15.47 link
VGG19_BN_Weights.IMAGENET1K_V1 74.218 91.842 143.7M 19.63 link
VGG19_Weights.IMAGENET1K_V1 72.376 90.876 143.7M 19.63 link
ViT_B_16_Weights.IMAGENET1K_V1 81.072 95.318 86.6M 17.56 link
ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1 85.304 97.65 86.9M 55.48 link
ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1 81.886 96.18 86.6M 17.56 link
ViT_B_32_Weights.IMAGENET1K_V1 75.912 92.466 88.2M 4.41 link
ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1 88.552 98.694 633.5M 1016.72 link
ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1 85.708 97.73 632.0M 167.29 link
ViT_L_16_Weights.IMAGENET1K_V1 79.662 94.638 304.3M 61.55 link
ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1 88.064 98.512 305.2M 361.99 link
ViT_L_16_Weights.IMAGENET1K_SWAG_LINEAR_V1 85.146 97.422 304.3M 61.55 link
ViT_L_32_Weights.IMAGENET1K_V1 76.972 93.07 306.5M 15.38 link
Wide_ResNet101_2_Weights.IMAGENET1K_V1 78.848 94.284 126.9M 22.75 link
Wide_ResNet101_2_Weights.IMAGENET1K_V2 82.51 96.02 126.9M 22.75 link
Wide_ResNet50_2_Weights.IMAGENET1K_V1 78.468 94.086 68.9M 11.4 link
Wide_ResNet50_2_Weights.IMAGENET1K_V2 81.602 95.758 68.9M 11.4 link

Quantized models

The following architectures provide support for INT8 quantized models, with or without pre-trained weights:

Here is an example of how to use the pre-trained quantized image classification models:

from torchvision.io import read_image
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

Step 1: Initialize model with the best available weights

weights = ResNet50_QuantizedWeights.DEFAULT
model = resnet50(weights=weights, quantize=True)
model.eval()

Step 2: Initialize the inference transforms

preprocess = weights.transforms()

Step 3: Apply inference preprocessing transforms

batch = preprocess(img).unsqueeze(0)

Step 4: Use the model and print the predicted category

prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at weights.meta["categories"].

Table of all available quantized classification weights

Accuracies are reported on ImageNet-1K using single crops:

Weight Acc@1 Acc@5 Params GIPS Recipe
GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1 69.826 89.404 6.6M 1.5 link
Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1 77.176 93.354 27.2M 5.71 link
MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1 71.658 90.15 3.5M 0.3 link
MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1 73.004 90.858 5.5M 0.22 link
ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 78.986 94.48 88.8M 16.41 link
ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V2 82.574 96.132 88.8M 16.41 link
ResNeXt101_64X4D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 82.898 96.326 83.5M 15.46 link
ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1 69.494 88.882 11.7M 1.81 link
ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1 75.92 92.814 25.6M 4.09 link
ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2 80.282 94.976 25.6M 4.09 link
ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 57.972 79.78 1.4M 0.04 link
ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 68.36 87.582 2.3M 0.14 link
ShuffleNet_V2_X1_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 72.052 90.7 3.5M 0.3 link
ShuffleNet_V2_X2_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 75.354 92.488 7.4M 0.58 link

Semantic Segmentation

Warning

The segmentation module is in Beta stage, and backward compatibility is not guaranteed.

The following semantic segmentation models are available, with or without pre-trained weights:

Here is an example of how to use the pre-trained semantic segmentation models:

from torchvision.io.image import read_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image

img = read_image("gallery/assets/dog1.jpg")

Step 1: Initialize model with the best available weights

weights = FCN_ResNet50_Weights.DEFAULT
model = fcn_resnet50(weights=weights)
model.eval()

Step 2: Initialize the inference transforms

preprocess = weights.transforms()

Step 3: Apply inference preprocessing transforms

batch = preprocess(img).unsqueeze(0)

Step 4: Use the model and visualize the prediction

prediction = model(batch)["out"]
normalized_masks = prediction.softmax(dim=1)
class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
mask = normalized_masks[0, class_to_idx["dog"]]
to_pil_image(mask).show()

The classes of the pre-trained model outputs can be found at weights.meta["categories"]. The output format of the models is illustrated in Semantic segmentation models.
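If a single label per pixel is needed instead of per-class scores, the class dimension can be collapsed with an argmax; a minimal sketch continuing the example above:

class_map = normalized_masks.argmax(dim=1)                        # shape (N, H, W): one class index per pixel
dog_pixels = (class_map[0] == class_to_idx["dog"]).sum().item()   # count pixels assigned to the "dog" class
print(f"pixels predicted as dog: {dog_pixels}")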

Table of all available semantic segmentation weights

All models are evaluated on a subset of COCO val2017, on the 20 categories that are present in the Pascal VOC dataset:

Weight Mean IoU pixelwise Acc Params GFLOPS Recipe
DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1 60.3 91.2 11.0M 10.45 link
DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1 67.4 92.4 61.0M 258.74 link
DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1 66.4 92.4 42.0M 178.72 link
FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1 63.7 91.9 54.3M 232.74 link
FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1 60.5 91.4 35.3M 152.72 link
LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1 57.9 91.2 3.2M 2.09 link

Object Detection, Instance Segmentation and Person Keypoint Detection

The pre-trained models for detection, instance segmentation and keypoint detection are initialized with the classification models in torchvision. The models expect a list of Tensor[C, H, W]. Check the constructor of the models for more information.
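A minimal sketch of that calling convention (the image sizes below are arbitrary, and random tensors stand in for real images):

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
model.eval()

images = [torch.rand(3, 480, 640), torch.rand(3, 520, 960)]  # each element is a Tensor[C, H, W]; sizes may differ
predictions = model(images)                                  # one dict of "boxes", "labels", "scores" per image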

Warning

The detection module is in Beta stage, and backward compatibility is not guaranteed.

Object Detection

The following object detection models are available, with or without pre-trained weights:

Here is an example of how to use the pre-trained object detection models:

from torchvision.io.image import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

Step 1: Initialize model with the best available weights

weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()

Step 2: Initialize the inference transforms

preprocess = weights.transforms()

Step 3: Apply inference preprocessing transforms

batch = [preprocess(img)]

Step 4: Use the model and visualize the prediction

prediction = model(batch)[0]
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, boxes=prediction["boxes"], labels=labels, colors="red", width=4, font_size=30)
im = to_pil_image(box.detach())
im.show()

The classes of the pre-trained model outputs can be found at weights.meta["categories"]. For details on how to plot the bounding boxes of the models, you may refer to Instance segmentation models.

Table of all available Object detection weights

Box MAPs are reported on COCO val2017:

Weight Box MAP Params GFLOPS Recipe
FCOS_ResNet50_FPN_Weights.COCO_V1 39.2 32.3M 128.21 link
FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1 22.8 19.4M 0.72 link
FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1 32.8 19.4M 4.49 link
FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1 46.7 43.7M 280.37 link
FasterRCNN_ResNet50_FPN_Weights.COCO_V1 37 41.8M 134.38 link
RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1 41.5 38.2M 152.24 link
RetinaNet_ResNet50_FPN_Weights.COCO_V1 36.4 34.0M 151.54 link
SSD300_VGG16_Weights.COCO_V1 25.1 35.6M 34.86 link
SSDLite320_MobileNet_V3_Large_Weights.COCO_V1 21.3 3.4M 0.58 link

Instance Segmentation

The following instance segmentation models are available, with or without pre-trained weights:

For details on how to plot the masks of the models, you may refer to Instance segmentation models.
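Analogous to the object detection example above, here is a minimal sketch of running an instance segmentation model and overlaying its thresholded masks (the 0.5 probability threshold is an arbitrary choice):

from torchvision.io.image import read_image
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_segmentation_masks
from torchvision.transforms.functional import to_pil_image

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = maskrcnn_resnet50_fpn_v2(weights=weights)
model.eval()

preprocess = weights.transforms()
prediction = model([preprocess(img)])[0]

masks = prediction["masks"].squeeze(1) > 0.5                 # (N, 1, H, W) probabilities -> boolean masks
overlay = draw_segmentation_masks(img, masks=masks, alpha=0.7)
to_pil_image(overlay).show()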

Table of all available Instance segmentation weights

Box and Mask MAPs are reported on COCO val2017:

Weight Box MAP Mask MAP Params GFLOPS Recipe
MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1 47.4 41.8 46.4M 333.58 link
MaskRCNN_ResNet50_FPN_Weights.COCO_V1 37.9 34.6 44.4M 134.38 link

Keypoint Detection

The following person keypoint detection models are available, with or without pre-trained weights:

The classes of the pre-trained model outputs can be found at weights.meta["keypoint_names"]. For details on how to plot the bounding boxes of the models, you may refer to Visualizing keypoints.
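Along the same lines, a minimal sketch of running the keypoint detection model and drawing its predicted keypoints (the 0.75 score threshold is arbitrary):

from torchvision.io.image import read_image
from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
from torchvision.utils import draw_keypoints
from torchvision.transforms.functional import to_pil_image

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
model = keypointrcnn_resnet50_fpn(weights=weights)
model.eval()

preprocess = weights.transforms()
prediction = model([preprocess(img)])[0]

keypoints = prediction["keypoints"][prediction["scores"] > 0.75]  # keep confident detections only
overlay = draw_keypoints(img, keypoints, colors="blue", radius=3)
to_pil_image(overlay).show()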

Table of all available Keypoint detection weights

Box and Keypoint MAPs are reported on COCO val2017:

Weight Box MAP Keypoint MAP Params GFLOPS Recipe
KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY 50.6 61.1 59.1M 133.92 link
KeypointRCNN_ResNet50_FPN_Weights.COCO_V1 54.6 65 59.1M 137.42 link

Video Classification

Warning

The video module is in Beta stage, and backward compatibility is not guaranteed.

The following video classification models are available, with or without pre-trained weights:

Here is an example of how to use the pre-trained video classification models:

from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights

vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
vid = vid[:32]  # optionally shorten duration

Step 1: Initialize model with the best available weights

weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights)
model.eval()

Step 2: Initialize the inference transforms

preprocess = weights.transforms()

Step 3: Apply inference preprocessing transforms

batch = preprocess(vid).unsqueeze(0)

Step 4: Use the model and print the predicted category

prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
score = prediction[label].item()
category_name = weights.meta["categories"][label]
print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at weights.meta["categories"].

Table of all available video classification weights

Accuracies are reported on Kinetics-400 using single crops for clip length 16:

Weight Acc@1 Acc@5 Params GFLOPS Recipe
MC3_18_Weights.KINETICS400_V1 63.96 84.13 11.7M 43.34 link
MViT_V1_B_Weights.KINETICS400_V1 78.477 93.582 36.6M 70.6 link
MViT_V2_S_Weights.KINETICS400_V1 80.757 94.665 34.5M 64.22 link
R2Plus1D_18_Weights.KINETICS400_V1 67.463 86.175 31.5M 40.52 link
R3D_18_Weights.KINETICS400_V1 63.2 83.479 33.4M 40.7 link
S3D_Weights.KINETICS400_V1 68.368 88.05 8.3M 17.98 link
Swin3D_B_Weights.KINETICS400_V1 79.427 94.386 88.0M 140.67 link
Swin3D_B_Weights.KINETICS400_IMAGENET22K_V1 81.643 95.574 88.0M 140.67 link
Swin3D_S_Weights.KINETICS400_V1 79.521 94.158 49.8M 82.84 link
Swin3D_T_Weights.KINETICS400_V1 77.715 93.519 28.2M 43.88 link

Optical Flow

The following Optical Flow models are available, with or without pre-trained weights: