Transforming and augmenting images — Torchvision 0.22 documentation

Torchvision supports common computer vision transformations in the torchvision.transforms and torchvision.transforms.v2 modules. Transforms can be used to transform or augment data for training or inference of different tasks (image classification, detection, segmentation, video classification).

Image Classification

import torch
from torchvision.transforms import v2

H, W = 32, 32
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = transforms(img)

Detection (re-using imports and transforms from above)

from torchvision import tv_tensors

img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
boxes = torch.randint(0, H // 2, size=(3, 4))
boxes[:, 2:] += boxes[:, :2]
boxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))

The same transforms can be used!

img, boxes = transforms(img, boxes)

And you can pass arbitrary input structures

output_dict = transforms({"image": img, "boxes": boxes})

Transforms are typically passed as the transform or transforms argument to the Datasets.
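For instance, a pipeline like the one above can be attached to a dataset through its transform argument. A minimal sketch; the CIFAR10 dataset and the "data" root are just illustrative choices:

import torch
from torchvision import datasets
from torchvision.transforms import v2

transforms = v2.Compose([
    v2.ToImage(),  # CIFAR10 yields PIL images; convert to tensor first
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.ToDtype(torch.float32, scale=True),
])

# The dataset applies the pipeline to every sample it returns.
dataset = datasets.CIFAR10(root="data", download=True, transform=transforms)
img, label = dataset[0]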

Start here

Whether you’re new to Torchvision transforms, or you’re already experienced with them, we encourage you to start with Getting started with transforms v2 in order to learn more about what can be done with the new v2 transforms.

Then, browse the sections below for general information and performance tips. The available transforms and functionals are listed in the API reference.

More information and tutorials can also be found in our example gallery, e.g. Transforms v2: End-to-end object detection/segmentation example or How to write your own v2 transforms.

Supported input types and conventions

Most transformations accept both PIL images and tensor inputs. Both CPU and CUDA tensors are supported. The result of both backends (PIL or Tensors) should be very close. In general, we recommend relying on the tensor backend for performance. The conversion transforms may be used to convert to and from PIL images, or for converting dtypes and ranges.

Tensor images are expected to be of shape (C, H, W), where C is the number of channels, and H and W refer to height and width. Most transforms support batched tensor input. A batch of Tensor images is a tensor of shape (N, C, H, W), where N is the number of images in the batch. The v2 transforms generally accept an arbitrary number of leading dimensions (..., C, H, W) and can handle batched images or batched videos.
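For instance, the same transform instance can be applied to a single image, a batch of images, or a batch of videos (a small sketch; the shapes are arbitrary choices):

import torch
from torchvision.transforms import v2

flip = v2.RandomHorizontalFlip(p=1.0)

img = torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8)           # (C, H, W)
batch = torch.randint(0, 256, size=(8, 3, 32, 32), dtype=torch.uint8)      # (N, C, H, W)
videos = torch.randint(0, 256, size=(4, 8, 3, 32, 32), dtype=torch.uint8)  # (N, T, C, H, W)

# Each call preserves the input shape; leading dimensions are passed through.
print(flip(img).shape, flip(batch).shape, flip(videos).shape)

Note that random transforms sample their parameters once per call, so all images in a batch receive the same transformation.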

Dtype and expected value range

The expected range of the values of a tensor image is implicitly defined by the tensor dtype. Tensor images with a float dtype are expected to have values in [0, 1]. Tensor images with an integer dtype are expected to have values in [0, MAX_DTYPE] where MAX_DTYPE is the largest value that can be represented in that dtype. Typically, images of dtype torch.uint8 are expected to have values in [0, 255].

Use ToDtype to convert both the dtype and range of the inputs.
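For example (a minimal sketch with an arbitrary uint8 input), scale=True rescales the values along with the dtype change, while scale=False only casts:

import torch
from torchvision.transforms import v2

img_u8 = torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8)

img_f32 = v2.ToDtype(torch.float32, scale=True)(img_u8)    # [0, 255] -> [0.0, 1.0]
img_back = v2.ToDtype(torch.uint8, scale=True)(img_f32)    # [0.0, 1.0] -> [0, 255]
img_cast = v2.ToDtype(torch.float32, scale=False)(img_u8)  # values stay in [0, 255]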

V1 or V2? Which one should I use?

TL;DR We recommend using the torchvision.transforms.v2 transforms instead of those in torchvision.transforms. They’re faster and they can do more things. Just change the import and you should be good to go. Moving forward, new features and improvements will only be considered for the v2 transforms.

In Torchvision 0.15 (March 2023), we released a new set of transforms available in the torchvision.transforms.v2 namespace. These transforms have a lot of advantages compared to the v1 ones (in torchvision.transforms):

These transforms are fully backward compatible with the v1 ones, so if you’re already using transforms from torchvision.transforms, all you need to do is update the import to torchvision.transforms.v2. In terms of output, there might be negligible differences due to implementation differences.

Performance considerations

We recommend the following guidelines to get the best performance out of the transforms:

- Rely on the v2 transforms from torchvision.transforms.v2
- Use tensors instead of PIL images
- Use torch.uint8 dtype, especially for resizing
- Resize with bilinear or bicubic mode

This is what a typical transform pipeline could look like:

from torchvision.transforms import v2

transforms = v2.Compose([
    v2.ToImage(),  # Convert to tensor, only needed if you had a PIL image
    v2.ToDtype(torch.uint8, scale=True),  # optional, most inputs are already uint8 at this point
    # ...
    v2.RandomResizedCrop(size=(224, 224), antialias=True),  # Or Resize(antialias=True)
    # ...
    v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

The above should give you the best performance in a typical training environment that relies on the torch.utils.data.DataLoader with num_workers > 0.

Transforms tend to be sensitive to the input strides / memory format. Some transforms will be faster with channels-first images while others prefer channels-last. Like torch operators, most transforms will preserve the memory format of the input, but this may not always be respected due to implementation details. You may want to experiment a bit if you’re chasing the very best performance. Using torch.compile() on individual transforms may also help factor out the memory format variable (e.g. on Normalize). Note that we’re talking about memory format, not tensor shape.

Note that resize transforms like Resize and RandomResizedCrop typically prefer channels-last input and tend not to benefit from torch.compile() at this time.
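A small sketch of the kind of experiment this suggests (the batch size and transform are arbitrary choices):

import torch
from torchvision.transforms import v2

normalize = v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

batch = torch.rand(16, 3, 224, 224)                     # channels-first strides
batch_cl = batch.to(memory_format=torch.channels_last)  # same shape, channels-last strides

# Time both layouts on your own pipeline; which one wins depends on the transform.
out = normalize(batch_cl)
print(out.shape)  # the shape is unchanged, only the strides differ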

Transform classes, functionals, and kernels

Transforms are available as classes like Resize, but also as functionals like resize() in the torchvision.transforms.v2.functional namespace. This is very much like the torch.nn package which defines both classes and functional equivalents in torch.nn.functional.

The functionals support PIL images, pure tensors, or TVTensors, e.g. both resize(image_tensor) and resize(boxes) are valid.
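For example (a short sketch re-using the tv_tensors setup from above):

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8)
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0, 0, 10, 10]]), format="XYXY", canvas_size=(32, 32)
)

out_img = F.resize(img, size=(64, 64), antialias=True)
out_boxes = F.resize(boxes, size=(64, 64))  # box coordinates are rescaled to the new canvas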

Note

Random transforms like RandomCrop will randomly sample some parameter each time they’re called. Their functional counterpart (crop()) does not do any kind of random sampling and thus has a slightly different parametrization. The get_params() class method of the transform classes can be used to perform parameter sampling when using the functional APIs.
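For instance, this is how crop parameters can be sampled once and then applied through the functional API (a minimal sketch):

import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8)

# Sample the crop parameters once...
top, left, height, width = v2.RandomCrop.get_params(img, output_size=(16, 16))
# ...then apply exactly the same crop, e.g. to several aligned inputs.
out = F.crop(img, top, left, height, width)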

The torchvision.transforms.v2.functional namespace also contains what we call the “kernels”. These are the low-level functions that implement the core functionalities for specific types, e.g. resize_bounding_boxes or resized_crop_mask. They are public, although not documented. Check the code to see which ones are available (note that those starting with a leading underscore are not public!). Kernels are only really useful if you want torchscript support for types like bounding boxes or masks.

Torchscript support

Most transform classes and functionals support torchscript. For composing transforms, use torch.nn.Sequential instead of Compose:

transforms = torch.nn.Sequential(
    CenterCrop(10),
    Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
scripted_transforms = torch.jit.script(transforms)

Warning

v2 transforms support torchscript, but if you call torch.jit.script() on a v2 class transform, you’ll actually end up with its (scripted) v1 equivalent. This may lead to slightly different results between the scripted and eager executions due to implementation differences between v1 and v2.

If you really need torchscript support for the v2 transforms, we recommend scripting the functionals from thetorchvision.transforms.v2.functional namespace to avoid surprises.

Also note that the functionals only support torchscript for pure tensors, which are always treated as images. If you need torchscript support for other types like bounding boxes or masks, you can rely on the low-level kernels.
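For instance, a scripted helper built on the functionals might look like this (a sketch; the helper function itself is ours, not part of the API):

import torch
from torchvision.transforms.v2 import functional as F

def hflip_and_pad(img: torch.Tensor) -> torch.Tensor:
    # Pure tensors are treated as images by the scripted functionals.
    img = F.horizontal_flip(img)
    return F.pad(img, padding=[4])

scripted = torch.jit.script(hflip_and_pad)
out = scripted(torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8))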

For any custom transformations to be used with torch.jit.script, they should be derived from torch.nn.Module.

See also: Torchscript support.

V2 API reference - Recommended

Geometry

Resizing

v2.Resize(size[, interpolation, max_size, ...]) Resize the input to the given size.

Functionals

v2.functional.resize(inpt, size[, ...]) See Resize for details.

Cropping

v2.RandomCrop(size[, padding, ...]) Crop the input at a random location.
v2.RandomResizedCrop(size[, scale, ratio, ...]) Crop a random portion of the input and resize it to a given size.
v2.RandomIoUCrop([min_scale, max_scale, ...]) Random IoU crop transformation from "SSD: Single Shot MultiBox Detector".
v2.CenterCrop(size) Crop the input at the center.
v2.FiveCrop(size) Crop the image or video into four corners and the central crop.
v2.TenCrop(size[, vertical_flip]) Crop the image or video into four corners and the central crop plus the flipped version of these (horizontal flipping is used by default).

Functionals

v2.functional.crop(inpt, top, left, height, ...) See RandomCrop for details.
v2.functional.resized_crop(inpt, top, left, ...) See RandomResizedCrop for details.
v2.functional.ten_crop(inpt, size[, ...]) See TenCrop for details.
v2.functional.center_crop(inpt, output_size) See CenterCrop for details.
v2.functional.five_crop(inpt, size) See FiveCrop for details.

Others

v2.RandomHorizontalFlip([p]) Horizontally flip the input with a given probability.
v2.RandomVerticalFlip([p]) Vertically flip the input with a given probability.
v2.Pad(padding[, fill, padding_mode]) Pad the input on all sides with the given "pad" value.
v2.RandomZoomOut([fill, side_range, p]) "Zoom out" transformation from "SSD: Single Shot MultiBox Detector".
v2.RandomRotation(degrees[, interpolation, ...]) Rotate the input by angle.
v2.RandomAffine(degrees[, translate, scale, ...]) Random affine transformation of the input keeping center invariant.
v2.RandomPerspective([distortion_scale, p, ...]) Perform a random perspective transformation of the input with a given probability.
v2.ElasticTransform([alpha, sigma, ...]) Transform the input with elastic transformations.

Functionals

v2.functional.horizontal_flip(inpt) See RandomHorizontalFlip for details.
v2.functional.vertical_flip(inpt) See RandomVerticalFlip for details.
v2.functional.pad(inpt, padding[, fill, ...]) See Pad for details.
v2.functional.rotate(inpt, angle[, ...]) See RandomRotation for details.
v2.functional.affine(inpt, angle, translate, ...) See RandomAffine for details.
v2.functional.perspective(inpt, startpoints, ...) See RandomPerspective for details.
v2.functional.elastic(inpt, displacement[, ...]) See ElasticTransform for details.

Color

v2.ColorJitter([brightness, contrast, ...]) Randomly change the brightness, contrast, saturation and hue of an image or video.
v2.RandomChannelPermutation() Randomly permute the channels of an image or video.
v2.RandomPhotometricDistort([brightness, ...]) Randomly distorts the image or video as used in SSD: Single Shot MultiBox Detector.
v2.Grayscale([num_output_channels]) Convert images or videos to grayscale.
v2.RGB() Convert images or videos to RGB (if they are not already RGB).
v2.RandomGrayscale([p]) Randomly convert images or videos to grayscale with a probability of p (default 0.1).
v2.GaussianBlur(kernel_size[, sigma]) Blurs the image or video with a randomly chosen Gaussian blur kernel.
v2.GaussianNoise([mean, sigma, clip]) Add Gaussian noise to images or videos.
v2.RandomInvert([p]) Inverts the colors of the given image or video with a given probability.
v2.RandomPosterize(bits[, p]) Posterize the image or video with a given probability by reducing the number of bits for each color channel.
v2.RandomSolarize(threshold[, p]) Solarize the image or video with a given probability by inverting all pixel values above a threshold.
v2.RandomAdjustSharpness(sharpness_factor[, p]) Adjust the sharpness of the image or video with a given probability.
v2.RandomAutocontrast([p]) Autocontrast the pixels of the given image or video with a given probability.
v2.RandomEqualize([p]) Equalize the histogram of the given image or video with a given probability.

Functionals

v2.functional.permute_channels(inpt, permutation) Permute the channels of the input according to the given permutation.
v2.functional.rgb_to_grayscale(inpt[, ...]) See Grayscale for details.
v2.functional.grayscale_to_rgb(inpt) See RGB for details.
v2.functional.to_grayscale(inpt[, ...]) See Grayscale for details.
v2.functional.gaussian_blur(inpt, kernel_size) See GaussianBlur for details.
v2.functional.gaussian_noise(inpt[, mean, ...]) See GaussianNoise for details.
v2.functional.invert(inpt) See RandomInvert for details.
v2.functional.posterize(inpt, bits) See RandomPosterize for details.
v2.functional.solarize(inpt, threshold) See RandomSolarize for details.
v2.functional.adjust_sharpness(inpt, ...) See RandomAdjustSharpness for details.
v2.functional.autocontrast(inpt) See RandomAutocontrast for details.
v2.functional.adjust_contrast(inpt, ...) Adjust contrast.
v2.functional.equalize(inpt) See RandomEqualize for details.
v2.functional.adjust_brightness(inpt, ...) Adjust brightness.
v2.functional.adjust_saturation(inpt, ...) Adjust saturation.
v2.functional.adjust_hue(inpt, hue_factor) Adjust hue.
v2.functional.adjust_gamma(inpt, gamma[, gain]) Adjust gamma.

Composition

v2.Compose(transforms) Composes several transforms together.
v2.RandomApply(transforms[, p]) Apply randomly a list of transformations with a given probability.
v2.RandomChoice(transforms[, p]) Apply single transformation randomly picked from a list.
v2.RandomOrder(transforms) Apply a list of transformations in a random order.

Miscellaneous

v2.LinearTransformation(...) Transform a tensor image or video with a square transformation matrix and a mean_vector computed offline.
v2.Normalize(mean, std[, inplace]) Normalize a tensor image or video with mean and standard deviation.
v2.RandomErasing([p, scale, ratio, value, ...]) Randomly select a rectangle region in the input image or video and erase its pixels.
v2.Lambda(lambd, *types) Apply a user-defined function as a transform.
v2.SanitizeBoundingBoxes([min_size, ...]) Remove degenerate/invalid bounding boxes and their corresponding labels and masks.
v2.ClampBoundingBoxes() Clamp bounding boxes to their corresponding image dimensions.
v2.UniformTemporalSubsample(num_samples) Uniformly subsample num_samples indices from the temporal dimension of the video.
v2.JPEG(quality) Apply JPEG compression and decompression to the given images.

Functionals

v2.functional.normalize(inpt, mean, std[, ...]) See Normalize for details.
v2.functional.erase(inpt, i, j, h, w, v[, ...]) See RandomErasing for details.
v2.functional.sanitize_bounding_boxes(...[, ...]) Remove degenerate/invalid bounding boxes and return the corresponding indexing mask.
v2.functional.clamp_bounding_boxes(inpt[, ...]) See ClampBoundingBoxes for details.
v2.functional.uniform_temporal_subsample(...) See UniformTemporalSubsample for details.
v2.functional.jpeg(image, quality) See JPEG for details.

Conversion

Note

Beware: some of the conversion transforms below will scale the values while performing the conversion, while others may not do any scaling. By scaling, we mean e.g. that a uint8 -> float32 conversion would map the [0, 255] range into [0, 1] (and vice-versa). See Dtype and expected value range.

v2.ToImage() Convert a tensor, ndarray, or PIL Image to Image; this does not scale values.
v2.ToPureTensor() Convert all TVTensors to pure tensors, removing associated metadata (if any).
v2.PILToTensor() Convert a PIL Image to a tensor of the same type - this does not scale values.
v2.ToPILImage([mode]) Convert a tensor or an ndarray to PIL Image.
v2.ToDtype(dtype[, scale]) Converts the input to a specific dtype, optionally scaling the values for images or videos.
v2.ConvertBoundingBoxFormat(format) Convert bounding box coordinates to the given format, e.g. from "CXCYWH" to "XYXY".

Functionals

v2.functional.to_image(inpt) See ToImage for details.
v2.functional.pil_to_tensor(pic) See PILToTensor for details.
v2.functional.to_pil_image(pic[, mode]) See ToPILImage for details.
v2.functional.to_dtype(inpt[, dtype, scale]) See ToDtype for details.
v2.functional.convert_bounding_box_format(inpt) See ConvertBoundingBoxFormat for details.

Deprecated

Auto-Augmentation

AutoAugment is a common Data Augmentation technique that can improve the accuracy of Image Classification models. Though the data augmentation policies are directly tied to the dataset they were trained on, empirical studies show that ImageNet policies provide significant improvements when applied to other datasets. In TorchVision we implemented 3 policies learned on the following datasets: ImageNet, CIFAR10 and SVHN. The new transform can be used standalone or mixed-and-matched with existing transforms:
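For instance (a minimal sketch; the input shape is an arbitrary choice and AutoAugment defaults to the ImageNet policy):

import torch
from torchvision.transforms import v2

augment = v2.Compose([
    v2.AutoAugment(),  # defaults to the ImageNet-learned policy
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
])
img = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)
out = augment(img)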

CutMix - MixUp

CutMix and MixUp are special transforms that are meant to be used on batches rather than on individual images, because they combine pairs of images together. These can be used after the dataloader (once the samples are batched), or as part of a collation function; a minimal sketch follows the table below. See How to use CutMix and MixUp for detailed usage examples.

v2.CutMix(*[, alpha, num_classes, labels_getter]) Apply CutMix to the provided batch of images and labels.
v2.MixUp(*[, alpha, num_classes, labels_getter]) Apply MixUp to the provided batch of images and labels.
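A minimal sketch of the post-dataloader usage (the batch size and number of classes are arbitrary choices):

import torch
from torchvision.transforms import v2

NUM_CLASSES = 10  # illustrative
cutmix = v2.CutMix(num_classes=NUM_CLASSES)

# A batch as it would come out of a DataLoader: images plus integer labels.
images = torch.rand(8, 3, 224, 224)
labels = torch.randint(0, NUM_CLASSES, size=(8,))

images, labels = cutmix(images, labels)  # labels become (8, NUM_CLASSES) soft targets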

Developer tools

v2.Transform() Base class to implement your own v2 transforms.
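Following the pattern described in How to write your own v2 transforms, a custom transform can be as small as this (a sketch; the class and its factor parameter are our own):

import torch
from torchvision.transforms import v2

class ScaleBy(v2.Transform):
    """Multiply the input values by a constant factor (illustrative only)."""

    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def transform(self, inpt, params):
        # `params` holds per-call sampled parameters; unused in this sketch.
        return inpt * self.factor

out = ScaleBy(0.5)(torch.rand(3, 32, 32))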

V1 API Reference

Geometry

Resize(size[, interpolation, max_size, ...]) Resize the input image to the given size.
RandomCrop(size[, padding, pad_if_needed, ...]) Crop the given image at a random location.
RandomResizedCrop(size[, scale, ratio, ...]) Crop a random portion of image and resize it to a given size.
CenterCrop(size) Crops the given image at the center.
FiveCrop(size) Crop the given image into four corners and the central crop.
TenCrop(size[, vertical_flip]) Crop the given image into four corners and the central crop plus the flipped version of these (horizontal flipping is used by default).
Pad(padding[, fill, padding_mode]) Pad the given image on all sides with the given "pad" value.
RandomRotation(degrees[, interpolation, ...]) Rotate the image by angle.
RandomAffine(degrees[, translate, scale, ...]) Random affine transformation of the image keeping center invariant.
RandomPerspective([distortion_scale, p, ...]) Performs a random perspective transformation of the given image with a given probability.
ElasticTransform([alpha, sigma, ...]) Transform a tensor image with elastic transformations.
RandomHorizontalFlip([p]) Horizontally flip the given image randomly with a given probability.
RandomVerticalFlip([p]) Vertically flip the given image randomly with a given probability.

Color

ColorJitter([brightness, contrast, ...]) Randomly change the brightness, contrast, saturation and hue of an image.
Grayscale([num_output_channels]) Convert image to grayscale.
RandomGrayscale([p]) Randomly convert image to grayscale with a probability of p (default 0.1).
GaussianBlur(kernel_size[, sigma]) Blurs image with randomly chosen Gaussian blur.
RandomInvert([p]) Inverts the colors of the given image randomly with a given probability.
RandomPosterize(bits[, p]) Posterize the image randomly with a given probability by reducing the number of bits for each color channel.
RandomSolarize(threshold[, p]) Solarize the image randomly with a given probability by inverting all pixel values above a threshold.
RandomAdjustSharpness(sharpness_factor[, p]) Adjust the sharpness of the image randomly with a given probability.
RandomAutocontrast([p]) Autocontrast the pixels of the given image randomly with a given probability.
RandomEqualize([p]) Equalize the histogram of the given image randomly with a given probability.

Composition

Compose(transforms) Composes several transforms together.
RandomApply(transforms[, p]) Apply randomly a list of transformations with a given probability.
RandomChoice(transforms[, p]) Apply single transformation randomly picked from a list.
RandomOrder(transforms) Apply a list of transformations in a random order.

Miscellaneous

LinearTransformation(transformation_matrix, ...) Transform a tensor image with a square transformation matrix and a mean_vector computed offline.
Normalize(mean, std[, inplace]) Normalize a tensor image with mean and standard deviation.
RandomErasing([p, scale, ratio, value, inplace]) Randomly selects a rectangle region in a torch.Tensor image and erases its pixels.
Lambda(lambd) Apply a user-defined lambda as a transform.

Conversion

Note

Beware: some of the conversion transforms below will scale the values while performing the conversion, while others may not do any scaling. By scaling, we mean e.g. that a uint8 -> float32 conversion would map the [0, 255] range into [0, 1] (and vice-versa). See Dtype and expected value range.

ToPILImage([mode]) Convert a tensor or an ndarray to PIL Image.
ToTensor() Convert a PIL Image or ndarray to tensor and scale the values accordingly.
PILToTensor() Convert a PIL Image to a tensor of the same type - this does not scale values.
ConvertImageDtype(dtype) Convert a tensor image to the given dtype and scale the values accordingly.

Auto-Augmentation

AutoAugment is a common Data Augmentation technique that can improve the accuracy of Image Classification models. Though the data augmentation policies are directly tied to the dataset they were trained on, empirical studies show that ImageNet policies provide significant improvements when applied to other datasets. In TorchVision we implemented 3 policies learned on the following datasets: ImageNet, CIFAR10 and SVHN. The new transform can be used standalone or mixed-and-matched with existing transforms.

Functional Transforms

adjust_brightness(img, brightness_factor) Adjust brightness of an image.
adjust_contrast(img, contrast_factor) Adjust contrast of an image.
adjust_gamma(img, gamma[, gain]) Perform gamma correction on an image.
adjust_hue(img, hue_factor) Adjust hue of an image.
adjust_saturation(img, saturation_factor) Adjust color saturation of an image.
adjust_sharpness(img, sharpness_factor) Adjust the sharpness of an image.
affine(img, angle, translate, scale, shear) Apply affine transformation on the image keeping image center invariant.
autocontrast(img) Maximize contrast of an image by remapping its pixels per channel so that the lowest becomes black and the lightest becomes white.
center_crop(img, output_size) Crops the given image at the center.
convert_image_dtype(image[, dtype]) Convert a tensor image to the given dtype and scale the values accordingly. This function does not support PIL Image.
crop(img, top, left, height, width) Crop the given image at specified location and output size.
equalize(img) Equalize the histogram of an image by applying a non-linear mapping to the input in order to create a uniform distribution of grayscale values in the output.
erase(img, i, j, h, w, v[, inplace]) Erase the input Tensor Image with given value.
five_crop(img, size) Crop the given image into four corners and the central crop.
gaussian_blur(img, kernel_size[, sigma]) Performs Gaussian blurring on the image by the given kernel.
get_dimensions(img) Returns the dimensions of an image as [channels, height, width].
get_image_num_channels(img) Returns the number of channels of an image.
get_image_size(img) Returns the size of an image as [width, height].
hflip(img) Horizontally flip the given image.
invert(img) Invert the colors of an RGB/grayscale image.
normalize(tensor, mean, std[, inplace]) Normalize a float tensor image with mean and standard deviation.
pad(img, padding[, fill, padding_mode]) Pad the given image on all sides with the given "pad" value.
perspective(img, startpoints, endpoints[, ...]) Perform perspective transform of the given image.
pil_to_tensor(pic) Convert a PIL Image to a tensor of the same type.
posterize(img, bits) Posterize an image by reducing the number of bits for each color channel.
resize(img, size[, interpolation, max_size, ...]) Resize the input image to the given size.
resized_crop(img, top, left, height, width, size) Crop the given image and resize it to desired size.
rgb_to_grayscale(img[, num_output_channels]) Convert RGB image to grayscale version of image.
rotate(img, angle[, interpolation, expand, ...]) Rotate the image by angle.
solarize(img, threshold) Solarize an RGB/grayscale image by inverting all pixel values above a threshold.
ten_crop(img, size[, vertical_flip]) Generate ten cropped images from the given image.
to_grayscale(img[, num_output_channels]) Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
to_pil_image(pic[, mode]) Convert a tensor or an ndarray to PIL Image.
to_tensor(pic) Convert a PIL Image or numpy.ndarray to tensor.
vflip(img) Vertically flip the given image.