Swin2SR

PyTorch

Overview

The Swin2SR model was proposed in Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration by Marcos V. Conde, Ui-Jin Choi, Maxime Burchi, Radu Timofte. Swin2SR improves the SwinIR model by incorporating Swin Transformer v2 layers, which mitigate issues such as training instability, the resolution gap between pre-training and fine-tuning, and data hunger.

The abstract from the paper is the following:

Compression plays an important role on the efficient transmission and storage of images and videos through band-limited systems such as streaming services, virtual reality or videogames. However, compression unavoidably leads to artifacts and the loss of the original information, which may severely degrade the visual quality. For these reasons, quality enhancement of compressed images has become a popular research topic. While most state-of-the-art image restoration methods are based on convolutional neural networks, other transformers-based methods such as SwinIR, show impressive performance on these tasks. In this paper, we explore the novel Swin Transformer V2, to improve SwinIR for image super-resolution, and in particular, the compressed input scenario. Using this method we can tackle the major issues in training transformer vision models, such as training instability, resolution gaps between pre-training and fine-tuning, and hunger on data. We conduct experiments on three representative tasks: JPEG compression artifacts removal, image super-resolution (classical and lightweight), and compressed image super-resolution. Experimental results demonstrate that our method, Swin2SR, can improve the training convergence and performance of SwinIR, and is a top-5 solution at the “AIM 2022 Challenge on Super-Resolution of Compressed Image and Video”.

Swin2SR architecture. Taken from the original paper.

This model was contributed by nielsr. The original code can be found here.

Resources

Demo notebooks for Swin2SR can be found here.

A demo Space for image super-resolution with Swin2SR can be found here.

Swin2SRImageProcessor

class transformers.Swin2SRImageProcessor

( do_rescale: bool = True rescale_factor: typing.Union[int, float] = 0.00392156862745098 do_pad: bool = True pad_size: int = 8 **kwargs )

Constructs a Swin2SR image processor.
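
The processor can be constructed directly with the arguments shown in the signature above, or loaded from a pretrained checkpoint via from_pretrained. A minimal sketch, assuming the caidas/swin2SR-classical-sr-x2-64 checkpoint used elsewhere on this page:

from transformers import Swin2SRImageProcessor

# construct directly with the defaults shown in the signature above
processor = Swin2SRImageProcessor(do_rescale=True, do_pad=True, pad_size=8)

# or load the processor configuration stored with a pretrained checkpoint
processor = Swin2SRImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")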

preprocess

( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']] do_rescale: typing.Optional[bool] = None rescale_factor: typing.Optional[float] = None do_pad: typing.Optional[bool] = None pad_size: typing.Optional[int] = None return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None data_format: typing.Union[str, transformers.image_utils.ChannelDimension] = <ChannelDimension.FIRST: 'channels_first'> input_data_format: typing.Union[str, transformers.image_utils.ChannelDimension, NoneType] = None )

Preprocess an image or batch of images.
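
As a rough illustration of the padding behavior (the input size and values below are arbitrary and made up for this sketch): with do_pad=True, the height and width of the returned pixel_values are padded up so that both are divisible by pad_size.

import numpy as np
from transformers import Swin2SRImageProcessor

processor = Swin2SRImageProcessor(pad_size=8)

# dummy channels-last uint8 image whose sides are not multiples of pad_size
image = np.random.randint(0, 256, size=(125, 130, 3), dtype=np.uint8)

inputs = processor.preprocess(image, return_tensors="pt")
print(inputs["pixel_values"].shape)
# both spatial dimensions are padded up to a multiple of 8, e.g. (1, 3, 128, 136)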

Swin2SRConfig

class transformers.Swin2SRConfig

( image_size = 64 patch_size = 1 num_channels = 3 num_channels_out = None embed_dim = 180 depths = [6, 6, 6, 6, 6, 6] num_heads = [6, 6, 6, 6, 6, 6] window_size = 8 mlp_ratio = 2.0 qkv_bias = True hidden_dropout_prob = 0.0 attention_probs_dropout_prob = 0.0 drop_path_rate = 0.1 hidden_act = 'gelu' use_absolute_embeddings = False initializer_range = 0.02 layer_norm_eps = 1e-05 upscale = 2 img_range = 1.0 resi_connection = '1conv' upsampler = 'pixelshuffle' **kwargs )

This is the configuration class to store the configuration of a Swin2SRModel. It is used to instantiate a Swin Transformer v2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Swin Transformer v2 caidas/swin2sr-classicalsr-x2-64 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

from transformers import Swin2SRConfig, Swin2SRModel

# Initializing a Swin2SR configuration with the default values
configuration = Swin2SRConfig()

# Initializing a model (with random weights) from that configuration
model = Swin2SRModel(configuration)

# Accessing the model configuration
configuration = model.config
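
The defaults correspond to a 2x classical super-resolution setup; individual fields can be overridden when building a model from scratch. A small sketch, with purely illustrative values:

from transformers import Swin2SRConfig, Swin2SRForImageSuperResolution

# illustrative overrides: a 4x upscaling model with a larger attention window
configuration = Swin2SRConfig(upscale=4, window_size=16)
model = Swin2SRForImageSuperResolution(configuration)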

Swin2SRModel

class transformers.Swin2SRModel

( config )

The bare Swin2SR Model transformer outputting raw hidden states without any specific head on top. This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( pixel_values: FloatTensor head_mask: typing.Optional[torch.FloatTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.BaseModelOutput or tuple(torch.FloatTensor)

Returns

transformers.modeling_outputs.BaseModelOutput or tuple(torch.FloatTensor)

A transformers.modeling_outputs.BaseModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Swin2SRConfig) and inputs.

The Swin2SRModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoImageProcessor, Swin2SRModel
import torch
from datasets import load_dataset

dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
image = dataset["test"]["image"][0]

image_processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
model = Swin2SRModel.from_pretrained("caidas/swin2SR-classical-sr-x2-64")

inputs = image_processor(image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state
print(list(last_hidden_states.shape))
# [1, 180, 488, 648]
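
The forward signature above also accepts output_hidden_states and output_attentions; a small follow-up sketch, reusing model and inputs from the example above:

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True, output_attentions=True)

print(len(outputs.hidden_states))   # number of recorded hidden states
print(outputs.attentions[0].shape)  # attention weights of the first attention layer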

Swin2SRForImageSuperResolution

class transformers.Swin2SRForImageSuperResolution

( config )

Swin2SR Model transformer with an upsampler head on top for image super-resolution and restoration.

This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
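
For quick experiments, recent versions of Transformers also expose an image-to-image pipeline that wraps the image processor, the model, and the post-processing; whether your installed version includes it is an assumption of this sketch:

from transformers import pipeline
from PIL import Image
import requests

upscaler = pipeline("image-to-image", model="caidas/swin2SR-classical-sr-x2-64")

url = "https://huggingface.co/spaces/jjourney1125/swin2sr/resolve/main/samples/butterfly.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# returns an upscaled PIL image
upscaled = upscaler(image)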

forward

( pixel_values: typing.Optional[torch.FloatTensor] = None head_mask: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.ImageSuperResolutionOutput or tuple(torch.FloatTensor)

Returns

transformers.modeling_outputs.ImageSuperResolutionOutput or tuple(torch.FloatTensor)

A transformers.modeling_outputs.ImageSuperResolutionOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Swin2SRConfig) and inputs.

The Swin2SRForImageSuperResolution forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

import torch
import numpy as np
from PIL import Image
import requests

from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution

processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")

url = "https://huggingface.co/spaces/jjourney1125/swin2sr/resolve/main/samples/butterfly.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# convert the reconstructed image to a uint8 numpy array in (height, width, channels) format
output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output = np.moveaxis(output, source=0, destination=-1)
output = (output * 255.0).round().astype(np.uint8)
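
To inspect or save the result, the uint8 array can be converted back into a PIL image, continuing from the example above (the output filename is arbitrary):

# `output` is the (height, width, channels) uint8 array produced above
upscaled_image = Image.fromarray(output)
upscaled_image.save("butterfly_upscaled.png")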
