FocalNet

PyTorch

Overview

The FocalNet model was proposed in Focal Modulation Networks by Jianwei Yang, Chunyuan Li, Xiyang Dai, Lu Yuan, Jianfeng Gao. FocalNets completely replace self-attention (used in models like ViT and Swin) with a focal modulation mechanism for modeling token interactions in vision. The authors claim that FocalNets outperform self-attention-based models at similar computational cost on image classification, object detection, and segmentation.

The abstract from the paper is the following:

We propose focal modulation networks (FocalNets in short), where self-attention (SA) is completely replaced by a focal modulation mechanism for modeling token interactions in vision. Focal modulation comprises three components: (i) hierarchical contextualization, implemented using a stack of depth-wise convolutional layers, to encode visual contexts from short to long ranges, (ii) gated aggregation to selectively gather contexts for each query token based on its content, and (iii) element-wise modulation or affine transformation to inject the aggregated context into the query. Extensive experiments show FocalNets outperform the state-of-the-art SA counterparts (e.g., Swin and Focal Transformers) with similar computational costs on the tasks of image classification, object detection, and segmentation. Specifically, FocalNets with tiny and base size achieve 82.3% and 83.9% top-1 accuracy on ImageNet-1K. After being pretrained on ImageNet-22K at 224 resolution, they attain 86.5% and 87.3% top-1 accuracy when finetuned at resolution 224 and 384, respectively. When transferred to downstream tasks, FocalNets exhibit clear superiority. For object detection with Mask R-CNN, FocalNet base trained with a 1× schedule outperforms the Swin counterpart by 2.1 points and already surpasses Swin trained with a 3× schedule (49.0 vs. 48.5). For semantic segmentation with UPerNet, FocalNet base at single-scale outperforms Swin by 2.4, and beats Swin at multi-scale (50.5 vs. 49.7). Using large FocalNet and Mask2Former, we achieve 58.5 mIoU for ADE20K semantic segmentation, and 57.9 PQ for COCO panoptic segmentation. Using huge FocalNet and DINO, we achieve 64.3 and 64.4 mAP on COCO minival and test-dev, respectively, establishing new SoTA on top of much larger attention-based models like SwinV2-G and BEiT-3.
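To make the three components concrete, here is a minimal PyTorch sketch of a focal modulation block. It is a simplified illustration based on the description above, not the exact Hugging Face or original implementation; the class name, the kernel-size schedule (focal_window + 2 * level), and the tensor-layout conventions are assumptions made for this example.

import torch
import torch.nn as nn

class FocalModulation(nn.Module):
    """Simplified focal modulation sketch (illustrative, not the reference code)."""

    def __init__(self, dim: int, focal_level: int = 2, focal_window: int = 3):
        super().__init__()
        self.focal_level = focal_level
        # One projection yields the query, the initial context, and per-level gates.
        self.f = nn.Linear(dim, 2 * dim + (focal_level + 1))
        # (i) hierarchical contextualization: depth-wise convolutions with growing
        # kernels encode visual contexts from short to long range.
        self.layers = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(
                    dim, dim,
                    kernel_size=focal_window + 2 * level,
                    padding=(focal_window + 2 * level) // 2,
                    groups=dim, bias=False,
                ),
                nn.GELU(),
            )
            for level in range(focal_level)
        )
        self.h = nn.Conv2d(dim, dim, kernel_size=1)  # modulator projection
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, height, width, channels)
        dim = x.shape[-1]
        q, ctx, gates = torch.split(self.f(x), [dim, dim, self.focal_level + 1], dim=-1)
        ctx = ctx.permute(0, 3, 1, 2)      # (B, C, H, W) for the convolutions
        gates = gates.permute(0, 3, 1, 2)  # (B, focal_level + 1, H, W)
        ctx_all = 0.0
        for level, layer in enumerate(self.layers):
            ctx = layer(ctx)
            # (ii) gated aggregation: each level's context is weighted per location.
            ctx_all = ctx_all + ctx * gates[:, level : level + 1]
        # Global average pooling serves as the coarsest (global) level.
        ctx_all = ctx_all + ctx.mean(dim=(2, 3), keepdim=True) * gates[:, self.focal_level :]
        # (iii) element-wise modulation of the query by the aggregated context.
        modulator = self.h(ctx_all).permute(0, 2, 3, 1)
        return self.proj(q * modulator)

x = torch.randn(1, 56, 56, 96)           # (batch, H, W, embed_dim)
print(FocalModulation(dim=96)(x).shape)  # torch.Size([1, 56, 56, 96])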

This model was contributed by nielsr. The original code can be found at https://github.com/microsoft/FocalNet.

FocalNetConfig

class transformers.FocalNetConfig


( image_size = 224, patch_size = 4, num_channels = 3, embed_dim = 96, use_conv_embed = False, hidden_sizes = [192, 384, 768, 768], depths = [2, 2, 6, 2], focal_levels = [2, 2, 2, 2], focal_windows = [3, 3, 3, 3], hidden_act = 'gelu', mlp_ratio = 4.0, hidden_dropout_prob = 0.0, drop_path_rate = 0.1, use_layerscale = False, layerscale_value = 0.0001, use_post_layernorm = False, use_post_layernorm_in_modulation = False, normalize_modulator = False, initializer_range = 0.02, layer_norm_eps = 1e-05, encoder_stride = 32, out_features = None, out_indices = None, **kwargs )


This is the configuration class to store the configuration of a FocalNetModel. It is used to instantiate a FocalNet model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the FocalNet microsoft/focalnet-tiny architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

>>> from transformers import FocalNetConfig, FocalNetModel

>>> # Initializing a FocalNet configuration with the defaults
>>> configuration = FocalNetConfig()

>>> # Initializing a model (with random weights) from that configuration
>>> model = FocalNetModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
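Any of the parameters in the signature above can also be overridden at construction time. A short sketch continuing the session above; the specific values below are arbitrary choices for illustration, not recommended settings:

>>> # Hypothetical custom configuration; values are illustrative only.
>>> configuration = FocalNetConfig(
...     depths=[2, 2, 4, 2],
...     focal_levels=[3, 3, 3, 3],
...     focal_windows=[3, 3, 3, 3],
...     use_conv_embed=True,
... )
>>> model = FocalNetModel(configuration)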

FocalNetModel

class transformers.FocalNetModel


( config, add_pooling_layer = True, use_mask_token = False )


The bare FocalNet Model outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( pixel_values: typing.Optional[torch.FloatTensor] = None, bool_masked_pos: typing.Optional[torch.BoolTensor] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None ) → transformers.models.focalnet.modeling_focalnet.FocalNetModelOutput or tuple(torch.FloatTensor)


Returns

transformers.models.focalnet.modeling_focalnet.FocalNetModelOutput or tuple(torch.FloatTensor)

A transformers.models.focalnet.modeling_focalnet.FocalNetModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (FocalNetConfig) and inputs.

The FocalNetModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

>>> from transformers import AutoImageProcessor, FocalNetModel
>>> import torch
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny")
>>> model = FocalNetModel.from_pretrained("microsoft/focalnet-tiny")

>>> inputs = image_processor(image, return_tensors="pt")

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 49, 768]
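The per-stage feature maps can be requested as well. A minimal sketch, continuing the session above; it assumes the output exposes a hidden_states tuple when output_hidden_states=True is passed, as the forward signature above suggests:

>>> with torch.no_grad():
...     outputs = model(**inputs, output_hidden_states=True)

>>> # One entry per stage (plus the patch embeddings), shaped (batch, seq_len, dim).
>>> for stage, hidden_state in enumerate(outputs.hidden_states):
...     print(stage, tuple(hidden_state.shape))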

FocalNetForMaskedImageModeling

class transformers.FocalNetForMaskedImageModeling


( config )


FocalNet Model with a decoder on top for masked image modeling.

This follows the same implementation as in SimMIM.

Note that we provide a script to pre-train this model on custom data in our examples directory.

This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( pixel_values: typing.Optional[torch.FloatTensor] = None, bool_masked_pos: typing.Optional[torch.BoolTensor] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None ) → transformers.models.focalnet.modeling_focalnet.FocalNetMaskedImageModelingOutput or tuple(torch.FloatTensor)


Returns

transformers.models.focalnet.modeling_focalnet.FocalNetMaskedImageModelingOutput or tuple(torch.FloatTensor)

A transformers.models.focalnet.modeling_focalnet.FocalNetMaskedImageModelingOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (FocalNetConfig) and inputs.

The FocalNetForMaskedImageModeling forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

>>> from transformers import AutoImageProcessor, FocalNetConfig, FocalNetForMaskedImageModeling
>>> import torch
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-base-simmim-window6-192")
>>> config = FocalNetConfig()
>>> model = FocalNetForMaskedImageModeling(config)

>>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
>>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
>>> # create a random boolean mask of shape (batch_size, num_patches)
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
>>> list(reconstructed_pixel_values.shape)
[1, 3, 192, 192]
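Note that FocalNetConfig() above uses the 224-resolution defaults, while this checkpoint's image processor resizes inputs to 192, so num_patches is computed from a different resolution than the processed image. One way to keep the two consistent (an illustrative assumption, not part of the original example) is to set the resolution on the config explicitly:

>>> config = FocalNetConfig(image_size=192)
>>> model = FocalNetForMaskedImageModeling(config)
>>> (config.image_size // config.patch_size) ** 2  # patches per 192x192 image
2304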

FocalNetForImageClassification

class transformers.FocalNetForImageClassification


( config )


FocalNet Model with an image classification head on top (a linear layer on top of the pooled output) e.g. for ImageNet.

This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( pixel_values: typing.Optional[torch.FloatTensor] = None, labels: typing.Optional[torch.LongTensor] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None ) → transformers.models.focalnet.modeling_focalnet.FocalNetImageClassifierOutput or tuple(torch.FloatTensor)


Returns

transformers.models.focalnet.modeling_focalnet.FocalNetImageClassifierOutput or tuple(torch.FloatTensor)

A transformers.models.focalnet.modeling_focalnet.FocalNetImageClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (FocalNetConfig) and inputs.

The FocalNetForImageClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

>>> from transformers import AutoImageProcessor, FocalNetForImageClassification
>>> import torch
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny")
>>> model = FocalNetForImageClassification.from_pretrained("microsoft/focalnet-tiny")

>>> inputs = image_processor(image, return_tensors="pt")

>>> with torch.no_grad():
...     logits = model(**inputs).logits

>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
tabby, tabby cat
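For fine-tuning, the forward signature above also accepts labels, in which case a classification loss is returned. A minimal sketch, continuing the session above; the target id 281 ("tabby, tabby cat" in the ImageNet-1k label map) is chosen here purely for illustration:

>>> labels = torch.tensor([281])  # illustrative target class id
>>> outputs = model(**inputs, labels=labels)
>>> print(outputs.loss.shape)
torch.Size([])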
