OWL-ViT

Overview

The OWL-ViT (short for Vision Transformer for Open-World Localization) was proposed in Simple Open-Vocabulary Object Detection with Vision Transformers by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. OWL-ViT is an open-vocabulary object detection network trained on a variety of (image, text) pairs. It can be used to query an image with one or multiple text queries to search for and detect target objects described in text.

The abstract from the paper is the following:

Combining simple architectures with large-scale pre-training has led to massive improvements in image classification. For object detection, pre-training and scaling approaches are less well established, especially in the long-tailed and open-vocabulary setting, where training data is relatively scarce. In this paper, we propose a strong recipe for transferring image-text models to open-vocabulary object detection. We use a standard Vision Transformer architecture with minimal modifications, contrastive image-text pre-training, and end-to-end detection fine-tuning. Our analysis of the scaling properties of this setup shows that increasing image-level pre-training and model size yield consistent improvements on the downstream detection task. We provide the adaptation strategies and regularizations needed to attain very strong performance on zero-shot text-conditioned and one-shot image-conditioned object detection. Code and models are available on GitHub.

OWL-ViT architecture. Taken from the original paper.

This model was contributed by adirik. The original code can be found here.

Usage tips

OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses CLIP as its multi-modal backbone, with a ViT-like Transformer to get visual features and a causal language model to get the text features. To use CLIP for detection, OWL-ViT removes the final token pooling layer of the vision model and attaches a lightweight classification and box head to each transformer output token. Open-vocabulary classification is enabled by replacing the fixed classification layer weights with the class-name embeddings obtained from the text model. The authors first train CLIP from scratch and fine-tune it end-to-end with the classification and box heads on standard detection datasets using a bipartite matching loss. One or multiple text queries per image can be used to perform zero-shot text-conditioned object detection.
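
To make the open-vocabulary classification idea concrete, here is a minimal, self-contained sketch (random tensors and illustrative shapes, not the library's actual code): each image token embedding is scored against the embeddings of the text queries, so the set of "classes" is simply whatever queries are passed at inference time.

```python
import torch

# Illustrative shapes: 576 = (768 / 32) ** 2 image tokens for a 768x768 input, 512-dim joint embedding space
batch_size, num_tokens, embed_dim, num_queries = 1, 576, 512, 2

image_token_embeds = torch.randn(batch_size, num_tokens, embed_dim)  # one embedding per ViT output token
text_query_embeds = torch.randn(batch_size, num_queries, embed_dim)  # one embedding per text query

# Normalize, then score every image token against every text query (in place of a fixed classifier layer)
image_token_embeds = image_token_embeds / image_token_embeds.norm(dim=-1, keepdim=True)
text_query_embeds = text_query_embeds / text_query_embeds.norm(dim=-1, keepdim=True)
pred_logits = torch.einsum("bnd,bqd->bnq", image_token_embeds, text_query_embeds)

# A lightweight box head predicts one box per token; high-scoring (token, query) pairs become detections
print(pred_logits.shape)  # torch.Size([1, 576, 2])
```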

OwlViTImageProcessor can be used to resize (or rescale) and normalize images for the model and CLIPTokenizer is used to encode the text. OwlViTProcessor wraps OwlViTImageProcessor and CLIPTokenizer into a single instance to both encode the text and prepare the images. The following example shows how to perform object detection using OwlViTProcessor and OwlViTForObjectDetection.

```python
>>> import requests
>>> from PIL import Image
>>> import torch

>>> from transformers import OwlViTProcessor, OwlViTForObjectDetection

>>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
>>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> texts = [["a photo of a cat", "a photo of a dog"]]
>>> inputs = processor(text=texts, images=image, return_tensors="pt")
>>> outputs = model(**inputs)

>>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
>>> target_sizes = torch.Tensor([image.size[::-1]])
>>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
>>> results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=0.1)
>>> i = 0  # Retrieve predictions for the first image for the corresponding text queries
>>> text = texts[i]
>>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
>>> for box, score, label in zip(boxes, scores, labels):
...     box = [round(i, 2) for i in box.tolist()]
...     print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
```

Resources

A demo notebook on using OWL-ViT for zero- and one-shot (image-guided) object detection can be found here.

OwlViTConfig

class transformers.OwlViTConfig

( text_config = None vision_config = None projection_dim = 512 logit_scale_init_value = 2.6592 return_dict = True **kwargs )

Parameters

OwlViTConfig is the configuration class to store the configuration of an OwlViTModel. It is used to instantiate an OWL-ViT model according to the specified arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the OWL-ViT google/owlvit-base-patch32 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

from_text_vision_configs

( text_config: Dict vision_config: Dict **kwargs ) → OwlViTConfig

An instance of a configuration object

Instantiate an OwlViTConfig (or a derived class) from an OWL-ViT text model configuration and an OWL-ViT vision model configuration.
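
For illustration, a minimal sketch of both ways of building a configuration (the sub-configurations are passed to from_text_vision_configs as dictionaries; google/owlvit-base-patch32 style defaults are assumed):

```python
>>> from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig, OwlViTModel

>>> # Initializing an OwlViTConfig with google/owlvit-base-patch32 style defaults
>>> configuration = OwlViTConfig()
>>> model = OwlViTModel(configuration)

>>> # Building a config from separate text and vision configurations (passed as dicts)
>>> text_config = OwlViTTextConfig().to_dict()
>>> vision_config = OwlViTVisionConfig().to_dict()
>>> configuration = OwlViTConfig.from_text_vision_configs(text_config=text_config, vision_config=vision_config)
```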

OwlViTTextConfig

class transformers.OwlViTTextConfig

( vocab_size = 49408 hidden_size = 512 intermediate_size = 2048 num_hidden_layers = 12 num_attention_heads = 8 max_position_embeddings = 16 hidden_act = 'quick_gelu' layer_norm_eps = 1e-05 attention_dropout = 0.0 initializer_range = 0.02 initializer_factor = 1.0 pad_token_id = 0 bos_token_id = 49406 eos_token_id = 49407 **kwargs )

Parameters

This is the configuration class to store the configuration of an OwlViTTextModel. It is used to instantiate an OwlViT text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the OwlViT google/owlvit-base-patch32 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

```python
>>> from transformers import OwlViTTextConfig, OwlViTTextModel

>>> # Initializing an OwlViTTextConfig with google/owlvit-base-patch32 style configuration
>>> configuration = OwlViTTextConfig()

>>> # Initializing an OwlViTTextModel (with random weights) from the configuration
>>> model = OwlViTTextModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```

OwlViTVisionConfig

class transformers.OwlViTVisionConfig

( hidden_size = 768 intermediate_size = 3072 num_hidden_layers = 12 num_attention_heads = 12 num_channels = 3 image_size = 768 patch_size = 32 hidden_act = 'quick_gelu' layer_norm_eps = 1e-05 attention_dropout = 0.0 initializer_range = 0.02 initializer_factor = 1.0 **kwargs )

Parameters

This is the configuration class to store the configuration of an OwlViTVisionModel. It is used to instantiate an OWL-ViT image encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the OWL-ViT google/owlvit-base-patch32 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

```python
>>> from transformers import OwlViTVisionConfig, OwlViTVisionModel

>>> # Initializing an OwlViTVisionConfig with google/owlvit-base-patch32 style configuration
>>> configuration = OwlViTVisionConfig()

>>> # Initializing an OwlViTVisionModel (with random weights) from the configuration
>>> model = OwlViTVisionModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```

OwlViTImageProcessor

class transformers.OwlViTImageProcessor

( do_resize = True size = None resample = <Resampling.BICUBIC: 3> do_center_crop = False crop_size = None do_rescale = True rescale_factor = 0.00392156862745098 do_normalize = True image_mean = None image_std = None **kwargs )

Parameters

Constructs an OWL-ViT image processor.

This image processor inherits from ImageProcessingMixin which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

preprocess

( images: Union do_resize: Optional = None size: Optional = None resample: Resampling = None do_center_crop: Optional = None crop_size: Optional = None do_rescale: Optional = None rescale_factor: Optional = None do_normalize: Optional = None image_mean: Union = None image_std: Union = None return_tensors: Union = None data_format: Union = <ChannelDimension.FIRST: 'channels_first'> input_data_format: Union = None )

Parameters

Prepares an image or batch of images for the model.
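
A short usage sketch (calling the image processor invokes preprocess; the checkpoint's default 768×768 input size is assumed):

```python
>>> import requests
>>> from PIL import Image
>>> from transformers import OwlViTImageProcessor

>>> image_processor = OwlViTImageProcessor.from_pretrained("google/owlvit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> # Resize, rescale and normalize the image; returns a batch with "pixel_values"
>>> inputs = image_processor(images=image, return_tensors="pt")
>>> list(inputs["pixel_values"].shape)
[1, 3, 768, 768]
```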

post_process_object_detection

( outputs threshold: float = 0.1 target_sizes: Union = None ) → List[Dict]

Parameters

A list of dictionaries, each dictionary containing the scores, labels and boxes for an image in the batch as predicted by the model.

Converts the raw output of OwlViTForObjectDetection into final bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format.

post_process_image_guided_detection

( outputs threshold = 0.0 nms_threshold = 0.3 target_sizes = None ) → List[Dict]

Parameters

A list of dictionaries, each dictionary containing the scores, labels and boxes for an image in the batch as predicted by the model. All labels are set to None, as OwlViTForObjectDetection.image_guided_detection performs one-shot (image-guided) object detection.

Converts the output of OwlViTForObjectDetection.image_guided_detection() into the format expected by the COCO API.

OwlViTFeatureExtractor

class transformers.OwlViTFeatureExtractor

__call__

Preprocess an image or a batch of images.

post_process

( outputs target_sizes ) → List[Dict]

Parameters

A list of dictionaries, each dictionary containing the scores, labels and boxes for an image in the batch as predicted by the model.

Converts the raw output of OwlViTForObjectDetection into final bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format.

post_process_image_guided_detection

( outputs threshold = 0.0 nms_threshold = 0.3 target_sizes = None ) → List[Dict]

Parameters

A list of dictionaries, each dictionary containing the scores, labels and boxes for an image in the batch as predicted by the model. All labels are set to None, as OwlViTForObjectDetection.image_guided_detection performs one-shot (image-guided) object detection.

Converts the output of OwlViTForObjectDetection.image_guided_detection() into the format expected by the COCO API.

OwlViTProcessor

class transformers.OwlViTProcessor

( image_processor = None tokenizer = None **kwargs )

Parameters

Constructs an OWL-ViT processor which wraps OwlViTImageProcessor and CLIPTokenizer/CLIPTokenizerFast into a single processor that inherits both the image processor and tokenizer functionalities. See the __call__() and decode() methods for more information.
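
A minimal usage sketch (the keys shown below reflect what the tokenizer and image processor each contribute):

```python
>>> import requests
>>> from PIL import Image
>>> from transformers import OwlViTProcessor

>>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> # Tokenizes the text queries and preprocesses the image in a single call
>>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
>>> sorted(inputs.keys())
['attention_mask', 'input_ids', 'pixel_values']
```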

batch_decode

This method forwards all its arguments to CLIPTokenizerFast’s batch_decode(). Please refer to the docstring of this method for more information.

decode

This method forwards all its arguments to CLIPTokenizerFast’s decode(). Please refer to the docstring of this method for more information.

post_process_image_guided_detection

( *args **kwargs )

This method forwards all its arguments to OwlViTImageProcessor.post_process_image_guided_detection(). Please refer to the docstring of this method for more information.

OwlViTModel

class transformers.OwlViTModel

( config: OwlViTConfig )

Parameters

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: Optional = None pixel_values: Optional = None attention_mask: Optional = None return_loss: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_base_image_embeds: Optional = None return_dict: Optional = None ) → transformers.models.owlvit.modeling_owlvit.OwlViTOutput or tuple(torch.FloatTensor)

Parameters

Returns

transformers.models.owlvit.modeling_owlvit.OwlViTOutput or tuple(torch.FloatTensor)

A transformers.models.owlvit.modeling_owlvit.OwlViTOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (OwlViTConfig) and inputs.

The OwlViTModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, OwlViTModel

>>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits_per_image = outputs.logits_per_image  # image-text similarity scores
>>> probs = logits_per_image.softmax(dim=1)  # softmax over the text queries
```

get_text_features

( input_ids: Optional = None attention_mask: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → text_features (torch.FloatTensor of shape (batch_size, output_dim))

Parameters

Returns

text_features (torch.FloatTensor of shape (batch_size, output_dim))

The text embeddings obtained by applying the projection layer to the pooled output of OwlViTTextModel.

The OwlViTModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
>>> from transformers import AutoProcessor, OwlViTModel

>>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
>>> inputs = processor(
...     text=[["a photo of a cat", "a photo of a dog"], ["photo of an astronaut"]], return_tensors="pt"
... )
>>> text_features = model.get_text_features(**inputs)
```

get_image_features

( pixel_values: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → image_features (torch.FloatTensor of shape (batch_size, output_dim))

Parameters

Returns

image_features (torch.FloatTensor of shape (batch_size, output_dim))

The image embeddings obtained by applying the projection layer to the pooled output of OwlViTVisionModel.

The OwlViTModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, OwlViTModel

>>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> image_features = model.get_image_features(**inputs)
```

OwlViTTextModel

class transformers.OwlViTTextModel

( config: OwlViTTextConfig )

forward

( input_ids: Tensor attention_mask: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (OwlViTTextConfig) and inputs.

The OwlViTTextModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
>>> from transformers import AutoProcessor, OwlViTTextModel

>>> model = OwlViTTextModel.from_pretrained("google/owlvit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
>>> inputs = processor(
...     text=[["a photo of a cat", "a photo of a dog"], ["photo of an astronaut"]], return_tensors="pt"
... )
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output
```

OwlViTVisionModel

class transformers.OwlViTVisionModel

( config: OwlViTVisionConfig )

forward

( pixel_values: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (OwlViTVisionConfig) and inputs.

The OwlViTVisionModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, OwlViTVisionModel

>>> model = OwlViTVisionModel.from_pretrained("google/owlvit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> inputs = processor(images=image, return_tensors="pt")

>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output
```

OwlViTForObjectDetection

class transformers.OwlViTForObjectDetection

( config: OwlViTConfig )

forward

( input_ids: Tensor pixel_values: FloatTensor attention_mask: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.models.owlvit.modeling_owlvit.OwlViTObjectDetectionOutput or tuple(torch.FloatTensor)

Parameters

Returns

transformers.models.owlvit.modeling_owlvit.OwlViTObjectDetectionOutput or tuple(torch.FloatTensor)

A transformers.models.owlvit.modeling_owlvit.OwlViTObjectDetectionOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (OwlViTConfig) and inputs.

The OwlViTForObjectDetection forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
>>> import requests
>>> from PIL import Image
>>> import torch
>>> from transformers import AutoProcessor, OwlViTForObjectDetection

>>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
>>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> texts = [["a photo of a cat", "a photo of a dog"]]
>>> inputs = processor(text=texts, images=image, return_tensors="pt")
>>> outputs = model(**inputs)

>>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
>>> target_sizes = torch.Tensor([image.size[::-1]])
>>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
>>> results = processor.post_process_object_detection(
...     outputs=outputs, threshold=0.1, target_sizes=target_sizes
... )

>>> i = 0  # Retrieve predictions for the first image for the corresponding text queries
>>> text = texts[i]
>>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]

>>> for box, score, label in zip(boxes, scores, labels):
...     box = [round(i, 2) for i in box.tolist()]
...     print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
```

image_guided_detection

( pixel_values: FloatTensor query_pixel_values: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.models.owlvit.modeling_owlvit.OwlViTImageGuidedObjectDetectionOutput or tuple(torch.FloatTensor)

Parameters

Returns

transformers.models.owlvit.modeling_owlvit.OwlViTImageGuidedObjectDetectionOutput or tuple(torch.FloatTensor)

A transformers.models.owlvit.modeling_owlvit.OwlViTImageGuidedObjectDetectionOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (OwlViTConfig) and inputs.

The OwlViTForObjectDetection forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
>>> import requests
>>> from PIL import Image
>>> import torch
>>> from transformers import AutoProcessor, OwlViTForObjectDetection

>>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch16")
>>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> query_url = "http://images.cocodataset.org/val2017/000000001675.jpg"
>>> query_image = Image.open(requests.get(query_url, stream=True).raw)
>>> inputs = processor(images=image, query_images=query_image, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model.image_guided_detection(**inputs)

>>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
>>> target_sizes = torch.Tensor([image.size[::-1]])
>>> # Convert outputs (bounding boxes and similarity scores) to final bounding boxes and scores
>>> results = processor.post_process_image_guided_detection(
...     outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes
... )
>>> i = 0  # Retrieve predictions for the first image
>>> boxes, scores = results[i]["boxes"], results[i]["scores"]
>>> for box, score in zip(boxes, scores):
...     box = [round(i, 2) for i in box.tolist()]
...     print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}")
Detected similar object with confidence 0.856 at location [10.94, 50.4, 315.8, 471.39]
Detected similar object with confidence 1.0 at location [334.84, 25.33, 636.16, 374.71]
```
