OWLv2

Overview

OWLv2 was proposed in Scaling Open-Vocabulary Object Detection by Matthias Minderer, Alexey Gritsenko and Neil Houlsby. OWLv2 scales up OWL-ViT using self-training, which uses an existing detector to generate pseudo-box annotations on image-text pairs. This results in large gains over the previous state of the art for zero-shot object detection.

The abstract from the paper is the following:

Open-vocabulary object detection has benefited greatly from pretrained vision-language models, but is still limited by the amount of available detection training data. While detection training data can be expanded by using Web image-text pairs as weak supervision, this has not been done at scales comparable to image-level pretraining. Here, we scale up detection data with self-training, which uses an existing detector to generate pseudo-box annotations on image-text pairs. Major challenges in scaling self-training are the choice of label space, pseudo-annotation filtering, and training efficiency. We present the OWLv2 model and OWL-ST self-training recipe, which address these challenges. OWLv2 surpasses the performance of previous state-of-the-art open-vocabulary detectors already at comparable training scales (~10M examples). However, with OWL-ST, we can scale to over 1B examples, yielding further large improvement: With an L/14 architecture, OWL-ST improves AP on LVIS rare classes, for which the model has seen no human box annotations, from 31.2% to 44.6% (43% relative improvement). OWL-ST unlocks Web-scale training for open-world localization, similar to what has been seen for image classification and language modelling.

OWLv2 high-level overview. Taken from the original paper.

This model was contributed by nielsr. The original code can be found here.

Usage example

OWLv2 is, just like its predecessor OWL-ViT, a zero-shot text-conditioned object detection model. OWL-ViT uses CLIP as its multi-modal backbone, with a ViT-like Transformer to get visual features and a causal language model to get the text features. To use CLIP for detection, OWL-ViT removes the final token pooling layer of the vision model and attaches a lightweight classification and box head to each transformer output token. Open-vocabulary classification is enabled by replacing the fixed classification layer weights with the class-name embeddings obtained from the text model. The authors first train CLIP from scratch and fine-tune it end-to-end with the classification and box heads on standard detection datasets using a bipartite matching loss. One or multiple text queries per image can be used to perform zero-shot text-conditioned object detection.
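The open-vocabulary classification step described above can be sketched in plain Python. Each image token embedding is compared with each text query embedding by cosine similarity, and the best-matching query gives that token its label. This is a minimal sketch under stated assumptions, not the Transformers implementation: the 3-dimensional vectors are made up, and the real class head is understood to also apply a learned per-token shift and scale to these similarities, which is omitted here.

```python
import math


def normalize(vector):
    """Scale a vector to unit L2 norm (the epsilon avoids division by zero)."""
    norm = math.sqrt(sum(x * x for x in vector)) + 1e-6
    return [x / norm for x in vector]


def class_logits(token_embeds, query_embeds):
    """Cosine similarity of every image token with every text query.

    Returns a len(token_embeds) x len(query_embeds) nested list.
    """
    tokens = [normalize(t) for t in token_embeds]
    queries = [normalize(q) for q in query_embeds]
    return [[sum(a * b for a, b in zip(t, q)) for q in queries] for t in tokens]


# Hypothetical 3-d embeddings standing in for real model outputs.
image_tokens = [[0.9, 0.1, 0.0],   # a "cat-like" image token
                [0.0, 0.2, 0.95]]  # a "dog-like" image token
text_queries = [[1.0, 0.0, 0.0],   # stands in for "a photo of a cat"
                [0.0, 0.0, 1.0]]   # stands in for "a photo of a dog"

logits = class_logits(image_tokens, text_queries)
# Each token takes the label of its most similar query.
labels = [max(range(len(row)), key=row.__getitem__) for row in logits]
print(labels)  # [0, 1]
```

Because the query embeddings come from the text encoder rather than a fixed weight matrix, swapping in a new list of class names changes the label space without retraining.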

Owlv2ImageProcessor can be used to rescale, pad to a square, resize and normalize images for the model, and CLIPTokenizer is used to encode the text. Owlv2Processor wraps Owlv2ImageProcessor and CLIPTokenizer into a single instance to both encode the text and prepare the images. The following example shows how to perform object detection using Owlv2Processor and Owlv2ForObjectDetection.

import requests
from PIL import Image
import torch

from transformers import Owlv2Processor, Owlv2ForObjectDetection

processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = [["a photo of a cat", "a photo of a dog"]]
inputs = processor(text=texts, images=image, return_tensors="pt")

# Inference only, so gradients are not needed
with torch.no_grad():
    outputs = model(**inputs)

# Target image sizes (height, width) to rescale box predictions [batch_size, 2]
target_sizes = torch.Tensor([image.size[::-1]])

# Convert outputs (bounding boxes and class logits) to final boxes and scores
results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=0.1)

i = 0  # Retrieve predictions for the first image and its text queries
text = texts[i]
boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
for box, score, label in zip(boxes, scores, labels):
    box = [round(coord, 2) for coord in box.tolist()]
    print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
# Detected a photo of a cat with confidence 0.614 at location [341.67, 23.39, 642.32, 371.35]
# Detected a photo of a cat with confidence 0.665 at location [6.75, 51.96, 326.62, 473.13]

Resources

The architecture of OWLv2 is identical to that of OWL-ViT; however, the object detection head now also includes an objectness classifier, which predicts the (query-agnostic) likelihood that a predicted box contains an object (as opposed to background). The objectness score can be used to rank or filter predictions independently of text queries. Usage of OWLv2 is identical to that of OWL-ViT, apart from a new, updated image processor (Owlv2ImageProcessor).
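A minimal sketch of ranking predictions by objectness alone, with no text query involved. The logits, box names and the top_k_by_objectness helper are hypothetical; in practice the scores would come from the model's objectness output, assumed here to be one raw logit per predicted box that a sigmoid turns into a probability.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def top_k_by_objectness(objectness_logits, boxes, k):
    """Return the k (probability, box) pairs with the highest objectness.

    objectness_logits holds one raw, query-agnostic logit per predicted box;
    boxes holds the matching boxes in any format.
    """
    scored = [(sigmoid(logit), box) for logit, box in zip(objectness_logits, boxes)]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:k]


# Hypothetical logits for four predicted boxes; no text query is involved.
objectness_logits = [-3.0, 2.0, 0.5, -1.0]
boxes = ["box_a", "box_b", "box_c", "box_d"]
top = top_k_by_objectness(objectness_logits, boxes, k=2)
for prob, box in top:
    print(box, round(prob, 3))
# box_b 0.881
# box_c 0.622
```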

Owlv2Config

class transformers.Owlv2Config

( text_config = None vision_config = None projection_dim = 512 logit_scale_init_value = 2.6592 return_dict = True **kwargs )

Parameters

Owlv2Config is the configuration class to store the configuration of an Owlv2Model. It is used to instantiate an OWLv2 model according to the specified arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the OWLv2 google/owlv2-base-patch16 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
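The default logit_scale_init_value of 2.6592 is, assuming it plays the same role as in CLIP, the natural logarithm of the multiplier applied to image-text similarity scores. Exponentiating it recovers roughly 1/0.07, i.e. a softmax temperature of about 0.07. A quick check:

```python
import math

# Default from the Owlv2Config signature above.
logit_scale_init_value = 2.6592

# Assuming CLIP-style usage: the scale is stored in log space and its
# exponential multiplies the image-text similarity scores.
scale = math.exp(logit_scale_init_value)
temperature = 1.0 / scale
print(round(scale, 2), round(temperature, 3))  # 14.28 0.07
```

Storing the scale in log space keeps the learned multiplier positive regardless of the sign of the underlying parameter.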

from_text_vision_configs

( text_config: Dict vision_config: Dict **kwargs ) → Owlv2Config

An instance of a configuration object

Instantiate an Owlv2Config (or a derived class) from an OWLv2 text model configuration and an OWLv2 vision model configuration.

Owlv2TextConfig

class transformers.Owlv2TextConfig

( vocab_size = 49408 hidden_size = 512 intermediate_size = 2048 num_hidden_layers = 12 num_attention_heads = 8 max_position_embeddings = 16 hidden_act = 'quick_gelu' layer_norm_eps = 1e-05 attention_dropout = 0.0 initializer_range = 0.02 initializer_factor = 1.0 pad_token_id = 0 bos_token_id = 49406 eos_token_id = 49407 **kwargs )

Parameters

This is the configuration class to store the configuration of an Owlv2TextModel. It is used to instantiate an OWLv2 text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the OWLv2 google/owlv2-base-patch16 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
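The text encoder only has max_position_embeddings = 16 positions, so every text query must fit into 16 tokens including its start and end tokens. The sketch below frames a query using the default bos_token_id, eos_token_id and pad_token_id from the signature above. The word-piece ids and the frame_query helper are hypothetical, and the truncation rule is an assumption about how an over-long query would be handled rather than the tokenizer's documented behavior.

```python
# Special token ids and context length from the Owlv2TextConfig defaults above.
MAX_POSITIONS = 16
PAD_ID, BOS_ID, EOS_ID = 0, 49406, 49407


def frame_query(token_ids):
    """Frame hypothetical word-piece ids as one fixed-length text query.

    Adds BOS/EOS, truncates the body so EOS still fits (an assumption about
    how over-long queries are handled), then pads to MAX_POSITIONS.
    Returns (input_ids, attention_mask).
    """
    body = list(token_ids)[: MAX_POSITIONS - 2]
    ids = [BOS_ID] + body + [EOS_ID]
    mask = [1] * len(ids)
    padding = MAX_POSITIONS - len(ids)
    return ids + [PAD_ID] * padding, mask + [0] * padding


# Made-up ids for a five-token query.
ids, mask = frame_query([320, 1125, 539, 320, 2368])
print(len(ids), sum(mask))  # 16 7
```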

Example:

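from transformers import Owlv2TextConfig, Owlv2TextModel

# Initializing an Owlv2TextConfig with a google/owlv2-base-patch16 style configuration
configuration = Owlv2TextConfig()

# Initializing an Owlv2TextModel (with random weights) from that configuration
model = Owlv2TextModel(configuration)

# Accessing the model configuration
configuration = model.config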

Owlv2VisionConfig

class transformers.Owlv2VisionConfig

( hidden_size = 768 intermediate_size = 3072 num_hidden_layers = 12 num_attention_heads = 12 num_channels = 3 image_size = 768 patch_size = 16 hidden_act = 'quick_gelu' layer_norm_eps = 1e-05 attention_dropout = 0.0 initializer_range = 0.02 initializer_factor = 1.0 **kwargs )

Parameters

This is the configuration class to store the configuration of an Owlv2VisionModel. It is used to instantiate an OWLv2 image encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the OWLv2 google/owlv2-base-patch16 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
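Because a box and class head is attached to each transformer output token, the number of predicted boxes follows from image_size and patch_size. A small sketch, assuming one box per patch token (the class token does not get its own box); the helper name is hypothetical. The 768 case uses the Owlv2VisionConfig defaults shown above, while the 960 case assumes inputs resized to 960x960, which is believed to be the default Owlv2ImageProcessor size.

```python
def num_predicted_boxes(image_size: int, patch_size: int) -> int:
    """Number of patch tokens, i.e. predicted boxes, for a square input.

    Assumes one box per patch token and an image_size divisible by patch_size.
    """
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    grid = image_size // patch_size
    return grid * grid


print(num_predicted_boxes(768, 16))  # 2304 (Owlv2VisionConfig defaults)
print(num_predicted_boxes(960, 16))  # 3600 (assuming a 960x960 input)
```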

Example:

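from transformers import Owlv2VisionConfig, Owlv2VisionModel

# Initializing an Owlv2VisionConfig with a google/owlv2-base-patch16 style configuration
configuration = Owlv2VisionConfig()

# Initializing an Owlv2VisionModel (with random weights) from that configuration
model = Owlv2VisionModel(configuration)

# Accessing the model configuration
configuration = model.config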

Owlv2ImageProcessor

class transformers.Owlv2ImageProcessor

( do_rescale: bool = True rescale_factor: Union = 0.00392156862745098 do_pad: bool = True do_resize: bool = True size: Dict = None resample: Resampling = <Resampling.BILINEAR: 2> do_normalize: bool = True image_mean: Union = None image_std: Union = None **kwargs )

Parameters

Constructs an OWLv2 image processor.

preprocess

( images: Union do_pad: bool = None do_resize: bool = None size: Dict = None do_rescale: bool = None rescale_factor: float = None do_normalize: bool = None image_mean: Union = None image_std: Union = None return_tensors: Union = None data_format: ChannelDimension = <ChannelDimension.FIRST: 'channels_first'> input_data_format: Union = None )

Parameters

Preprocess an image or batch of images.
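The padding and resizing geometry can be sketched without touching any pixels. This assumes, rather than documents, that Owlv2ImageProcessor pads on the bottom and right to a square of side max(height, width) before resizing that square to 960x960; the helper name is hypothetical. Under these assumptions, the real content of the 640x480 COCO image used in the examples would occupy the top 720 rows of the resized square.

```python
def pad_then_resize_geometry(height, width, target=960):
    """Hypothetical helper describing pad-to-square followed by resizing.

    Assumes padding on the bottom and right to a square of side
    max(height, width), then resizing that square to target x target.
    Returns (padded_side, scale, (content_height, content_width)), where the
    content size is the region of the resized image covered by real pixels.
    """
    side = max(height, width)
    scale = target / side
    return side, scale, (round(height * scale), round(width * scale))


# The 640x480 (width x height) COCO image used in the examples.
print(pad_then_resize_geometry(480, 640))  # (640, 1.5, (720, 960))
```

Padding before resizing preserves the aspect ratio, so objects are not stretched on their way into the square model input.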

post_process_object_detection

( outputs threshold: float = 0.1 target_sizes: Union = None ) → List[Dict]

Parameters

A list of dictionaries, each dictionary containing the scores, labels and boxes for an image in the batch as predicted by the model.

Converts the raw output of Owlv2ForObjectDetection into final bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format.
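The conversion can be sketched without PyTorch for a single image: take the best-scoring text query per box, turn its logit into a score with a sigmoid, drop boxes whose score does not exceed threshold, convert the normalized (center_x, center_y, width, height) box to corners and scale it by the target size. This is a hypothetical re-implementation for illustration only; the logits and boxes are made up, and the real method operates on batched tensors.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def center_to_corners(box):
    """(cx, cy, w, h) -> (x0, y0, x1, y1), all in the same units."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def post_process(logits, boxes, image_height, image_width, threshold=0.1):
    """Hypothetical single-image version of the filtering steps.

    logits: per box, one raw logit per text query.
    boxes: per box, normalized (cx, cy, w, h) in [0, 1].
    Returns a list of (score, label, (x0, y0, x1, y1)) in pixel units.
    """
    results = []
    for box_logits, box in zip(logits, boxes):
        label = max(range(len(box_logits)), key=box_logits.__getitem__)
        score = sigmoid(box_logits[label])
        if score <= threshold:
            continue  # keep only scores strictly above the threshold
        x0, y0, x1, y1 = center_to_corners(box)
        scaled = (x0 * image_width, y0 * image_height, x1 * image_width, y1 * image_height)
        results.append((score, label, scaled))
    return results


# Two hypothetical predicted boxes scored against two text queries.
logits = [[2.0, -1.0], [-4.0, -3.0]]
boxes = [(0.5, 0.5, 0.2, 0.4), (0.1, 0.1, 0.05, 0.05)]
results = post_process(logits, boxes, image_height=480, image_width=640)
for score, label, box in results:
    print(label, round(score, 3), [round(c, 2) for c in box])
# 0 0.881 [256.0, 144.0, 384.0, 336.0]
```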

post_process_image_guided_detection

( outputs threshold = 0.0 nms_threshold = 0.3 target_sizes = None ) → List[Dict]

Parameters

A list of dictionaries, each dictionary containing the scores, labels and boxes for an image in the batch as predicted by the model. All labels are set to None, as Owlv2ForObjectDetection.image_guided_detection performs one-shot object detection.

Converts the output of Owlv2ForObjectDetection.image_guided_detection() into the format expected by the COCO API.
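The nms_threshold parameter controls non-maximum suppression. A minimal greedy NMS sketch in plain Python, assuming boxes in (x0, y0, x1, y1) format: boxes are visited in descending score order, and a box is kept only if its intersection-over-union with every already-kept box is at most nms_threshold. This is the textbook variant, not necessarily line-for-line what the method does; the boxes and scores are made up.

```python
def iou(a, b):
    """Intersection-over-union of two (x0, y0, x1, y1) boxes."""
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, nms_threshold=0.3):
    """Return indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    kept = []
    for i in order:
        # Suppress a box that overlaps any already-kept box too much.
        if all(iou(boxes[i], boxes[j]) <= nms_threshold for j in kept):
            kept.append(i)
    return kept


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
print(greedy_nms(boxes, scores))  # [0, 2]
```

Here the second box overlaps the first with an IoU of 81/119 (about 0.68), well above 0.3, so it is suppressed; the third box does not overlap either and survives.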

Owlv2Processor

class transformers.Owlv2Processor

( image_processor tokenizer **kwargs )

Parameters

Constructs an Owlv2 processor which wraps Owlv2ImageProcessor and CLIPTokenizer/CLIPTokenizerFast into a single processor that inherits both the image processor and tokenizer functionalities. See __call__() and decode() for more information.
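When text is a nested list (one list of queries per image), the processor is assumed to pad every image's query list to the longest one before tokenizing, which is why the flattened text batch has batch_size * max_queries rows. A sketch of that padding step; the pad_queries helper is hypothetical and the single-space padding string is an assumption.

```python
def pad_queries(texts, pad_value=" "):
    """Pad each image's list of text queries to the longest list.

    Hypothetical helper; the single-space pad value is an assumption.
    """
    max_queries = max(len(queries) for queries in texts)
    return [queries + [pad_value] * (max_queries - len(queries)) for queries in texts]


texts = [["a photo of a cat", "a photo of a dog"], ["a photo of an astronaut"]]
padded = pad_queries(texts)
# Flattening gives batch_size * max_queries rows for the tokenizer.
flat = [query for queries in padded for query in queries]
print(len(flat))  # 4
```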

batch_decode

This method forwards all its arguments to CLIPTokenizerFast’s batch_decode(). Please refer to the docstring of this method for more information.

decode

This method forwards all its arguments to CLIPTokenizerFast’s decode(). Please refer to the docstring of this method for more information.

post_process_image_guided_detection

( *args **kwargs )

This method forwards all its arguments to Owlv2ImageProcessor.post_process_image_guided_detection(). Please refer to the docstring of this method for more information.

Owlv2Model

class transformers.Owlv2Model

( config: Owlv2Config )

Parameters

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

forward

( input_ids: Optional = None pixel_values: Optional = None attention_mask: Optional = None return_loss: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_base_image_embeds: Optional = None return_dict: Optional = None ) → transformers.models.owlv2.modeling_owlv2.Owlv2Output or tuple(torch.FloatTensor)

Parameters

Returns

transformers.models.owlv2.modeling_owlv2.Owlv2Output or tuple(torch.FloatTensor)

A transformers.models.owlv2.modeling_owlv2.Owlv2Output or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (<class 'transformers.models.owlv2.configuration_owlv2.Owlv2Config'>) and inputs.

The Owlv2Model forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from PIL import Image
import requests
import torch
from transformers import AutoProcessor, Owlv2Model

model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
logits_per_image = outputs.logits_per_image  # image-text similarity scores
probs = logits_per_image.softmax(dim=1)  # softmax over the text queries gives label probabilities

get_text_features

( input_ids: Optional = None attention_mask: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → text_features (torch.FloatTensor of shape (batch_size, output_dim))

Parameters

Returns

text_features (torch.FloatTensor of shape (batch_size, output_dim))

The text embeddings obtained by applying the projection layer to the pooled output of Owlv2TextModel.

The Owlv2Model forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from transformers import AutoProcessor, Owlv2Model

model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
inputs = processor(
    text=[["a photo of a cat", "a photo of a dog"], ["a photo of an astronaut"]], return_tensors="pt"
)
text_features = model.get_text_features(**inputs)

get_image_features

( pixel_values: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → image_features (torch.FloatTensor of shape (batch_size, output_dim))

Parameters

Returns

image_features (torch.FloatTensor of shape (batch_size, output_dim))

The image embeddings obtained by applying the projection layer to the pooled output of Owlv2VisionModel.

The Owlv2Model forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from PIL import Image
import requests
from transformers import AutoProcessor, Owlv2Model

model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")
image_features = model.get_image_features(**inputs)

Owlv2TextModel

class transformers.Owlv2TextModel

( config: Owlv2TextConfig )

forward

( input_ids: Tensor attention_mask: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (<class 'transformers.models.owlv2.configuration_owlv2.Owlv2TextConfig'>) and inputs.

The Owlv2TextModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from transformers import AutoProcessor, Owlv2TextModel

model = Owlv2TextModel.from_pretrained("google/owlv2-base-patch16")
processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
inputs = processor(
    text=[["a photo of a cat", "a photo of a dog"], ["a photo of an astronaut"]], return_tensors="pt"
)
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled (EOS token) states

Owlv2VisionModel

class transformers.Owlv2VisionModel

( config: Owlv2VisionConfig )

forward

( pixel_values: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (<class 'transformers.models.owlv2.configuration_owlv2.Owlv2VisionConfig'>) and inputs.

The Owlv2VisionModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from PIL import Image
import requests
from transformers import AutoProcessor, Owlv2VisionModel

model = Owlv2VisionModel.from_pretrained("google/owlv2-base-patch16")
processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output

Owlv2ForObjectDetection

class transformers.Owlv2ForObjectDetection

( config: Owlv2Config )

forward

( input_ids: Tensor pixel_values: FloatTensor attention_mask: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.models.owlv2.modeling_owlv2.Owlv2ObjectDetectionOutput or tuple(torch.FloatTensor)

Parameters

Returns

transformers.models.owlv2.modeling_owlv2.Owlv2ObjectDetectionOutput or tuple(torch.FloatTensor)

A transformers.models.owlv2.modeling_owlv2.Owlv2ObjectDetectionOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (<class 'transformers.models.owlv2.configuration_owlv2.Owlv2Config'>) and inputs.

The Owlv2ForObjectDetection forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

import requests
from PIL import Image
import torch
from transformers import AutoProcessor, Owlv2ForObjectDetection

processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = [["a photo of a cat", "a photo of a dog"]]
inputs = processor(text=texts, images=image, return_tensors="pt")

# Forward pass without gradient tracking
with torch.no_grad():
    outputs = model(**inputs)

# Target image sizes (height, width) to rescale box predictions [batch_size, 2]
target_sizes = torch.Tensor([image.size[::-1]])

# Convert outputs (bounding boxes and class logits) to final boxes and scores
results = processor.post_process_object_detection(
    outputs=outputs, threshold=0.2, target_sizes=target_sizes
)

i = 0  # Retrieve predictions for the first image and its text queries
text = texts[i]
boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]

for box, score, label in zip(boxes, scores, labels):
    box = [round(coord, 2) for coord in box.tolist()]
    print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
# Detected a photo of a cat with confidence 0.614 at location [341.67, 23.39, 642.32, 371.35]
# Detected a photo of a cat with confidence 0.665 at location [6.75, 51.96, 326.62, 473.13]

image_guided_detection

( pixel_values: FloatTensor query_pixel_values: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.models.owlv2.modeling_owlv2.Owlv2ImageGuidedObjectDetectionOutput or tuple(torch.FloatTensor)

Parameters

Returns

transformers.models.owlv2.modeling_owlv2.Owlv2ImageGuidedObjectDetectionOutput or tuple(torch.FloatTensor)

A transformers.models.owlv2.modeling_owlv2.Owlv2ImageGuidedObjectDetectionOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (<class 'transformers.models.owlv2.configuration_owlv2.Owlv2Config'>) and inputs.

The Owlv2ForObjectDetection forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

import requests
from PIL import Image
import torch
from transformers import AutoProcessor, Owlv2ForObjectDetection

processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
query_url = "http://images.cocodataset.org/val2017/000000001675.jpg"
query_image = Image.open(requests.get(query_url, stream=True).raw)
inputs = processor(images=image, query_images=query_image, return_tensors="pt")

# Forward pass without gradient tracking
with torch.no_grad():
    outputs = model.image_guided_detection(**inputs)

# Target image sizes (height, width) to rescale box predictions [batch_size, 2]
target_sizes = torch.Tensor([image.size[::-1]])

# Convert outputs (bounding boxes and class logits) to the COCO API format
results = processor.post_process_image_guided_detection(
    outputs=outputs, threshold=0.9, nms_threshold=0.3, target_sizes=target_sizes
)
i = 0  # Retrieve predictions for the first image
boxes, scores = results[i]["boxes"], results[i]["scores"]
for box, score in zip(boxes, scores):
    box = [round(coord, 2) for coord in box.tolist()]
    print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}")
# Detected similar object with confidence 0.938 at location [327.31, 54.94, 547.39, 268.06]
# Detected similar object with confidence 0.959 at location [5.78, 360.65, 619.12, 366.39]
# Detected similar object with confidence 0.902 at location [2.85, 360.01, 627.63, 380.8]
# Detected similar object with confidence 0.985 at location [176.98, -29.45, 672.69, 182.83]
# Detected similar object with confidence 1.0 at location [6.53, 14.35, 624.87, 470.82]
# Detected similar object with confidence 0.998 at location [579.98, 29.14, 615.49, 489.05]
# Detected similar object with confidence 0.985 at location [206.15, 10.53, 247.74, 466.01]
# Detected similar object with confidence 0.947 at location [18.62, 429.72, 646.5, 457.72]
# Detected similar object with confidence 0.996 at location [523.88, 20.69, 586.84, 483.18]
# Detected similar object with confidence 0.998 at location [3.39, 360.59, 617.29, 499.21]
# Detected similar object with confidence 0.969 at location [4.47, 449.05, 614.5, 474.76]
# Detected similar object with confidence 0.966 at location [31.44, 463.65, 654.66, 471.07]
# Detected similar object with confidence 0.924 at location [30.93, 468.07, 635.35, 475.39]
