CLIP

This model was released on 2021-02-26 and added to Hugging Face Transformers on 2021-05-12.

CLIP is a multimodal vision and language model motivated by overcoming the fixed number of object categories used when training a conventional computer vision model. CLIP learns about images directly from raw text by jointly training on 400M (image, text) pairs. Pretraining at this scale enables zero-shot transfer to downstream tasks. CLIP uses an image encoder and a text encoder to extract visual and text features. Both sets of features are projected to a latent space with the same number of dimensions, and their dot product gives a similarity score.
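To make that scoring step concrete, here is a minimal sketch (not an official example) that reproduces the similarity computation by hand with the openai/clip-vit-base-patch32 checkpoint, using the get_text_features() and get_image_features() methods documented further down this page; the explicit normalization and logit_scale factor mirror what CLIPModel.forward() applies internally.

import torch
from transformers import AutoProcessor, CLIPModel
from transformers.image_utils import load_image

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = load_image("http://images.cocodataset.org/val2017/000000039769.jpg")
texts = ["a photo of a cat", "a photo of a dog"]

text_inputs = processor(text=texts, padding=True, return_tensors="pt")
image_inputs = processor(images=image, return_tensors="pt")

with torch.inference_mode():
    text_embeds = model.get_text_features(**text_inputs)     # shape (2, projection_dim)
    image_embeds = model.get_image_features(**image_inputs)  # shape (1, projection_dim)
    # L2-normalize both embeddings, then take the scaled dot product as the similarity score
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
    similarity = model.logit_scale.exp() * image_embeds @ text_embeds.t()
    probs = similarity.softmax(dim=-1)

print(probs)  # probability of each caption matching the image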

You can find all the original CLIP checkpoints under the OpenAI organization.

Click on the CLIP models in the right sidebar for more examples of how to apply CLIP to different image and language tasks.

The example below demonstrates how to calculate similarity scores between multiple text descriptions and an image with Pipeline or the AutoModel class.

import torch
from transformers import pipeline

clip = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-base-patch32",
    dtype=torch.bfloat16,
    device=0,
)
labels = ["a photo of a cat", "a photo of a dog", "a photo of a car"]
clip("http://images.cocodataset.org/val2017/000000039769.jpg", candidate_labels=labels)
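The same comparison with the AutoModel class, sketched here along the lines of the CLIPModel example further down this page (the checkpoint, labels, and load_image helper are the ones already used in this document):

import torch
from transformers import AutoModel, AutoProcessor
from transformers.image_utils import load_image

model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")  # resolves to CLIPModel
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = load_image("http://images.cocodataset.org/val2017/000000039769.jpg")
labels = ["a photo of a cat", "a photo of a dog", "a photo of a car"]

inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)

with torch.inference_mode():
    outputs = model(**inputs)

# higher score = better match between the image and the candidate label
probs = outputs.logits_per_image.softmax(dim=-1)[0]
print({label: round(p.item(), 3) for label, p in zip(labels, probs)})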

Notes

CLIPConfig

class transformers.CLIPConfig

< source >

( text_config = None vision_config = None projection_dim = 512 logit_scale_init_value = 2.6592 **kwargs )

Parameters

CLIPConfig is the configuration class to store the configuration of a CLIPModel. It is used to instantiate a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the CLIP openai/clip-vit-base-patch32 architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Example:

from transformers import CLIPConfig, CLIPModel

# Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
configuration = CLIPConfig()

# Initializing a CLIPModel (with random weights) from the configuration
model = CLIPModel(configuration)

# Accessing the model configuration
configuration = model.config

# A CLIPConfig can also be initialized from a CLIPTextConfig and a CLIPVisionConfig
from transformers import CLIPTextConfig, CLIPVisionConfig

config_text = CLIPTextConfig()
config_vision = CLIPVisionConfig()

config = CLIPConfig(text_config=config_text, vision_config=config_vision)

CLIPTextConfig

class transformers.CLIPTextConfig

< source >

( vocab_size = 49408 hidden_size = 512 intermediate_size = 2048 projection_dim = 512 num_hidden_layers = 12 num_attention_heads = 8 max_position_embeddings = 77 hidden_act = 'quick_gelu' layer_norm_eps = 1e-05 attention_dropout = 0.0 initializer_range = 0.02 initializer_factor = 1.0 pad_token_id = 1 bos_token_id = 49406 eos_token_id = 49407 **kwargs )

Parameters

This is the configuration class to store the configuration of a CLIPTextModel. It is used to instantiate a CLIP text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the text encoder of the CLIP openai/clip-vit-base-patch32 architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Example:

from transformers import CLIPTextConfig, CLIPTextModel

configuration = CLIPTextConfig()

model = CLIPTextModel(configuration)

configuration = model.config

CLIPVisionConfig

class transformers.CLIPVisionConfig

< source >

( hidden_size = 768 intermediate_size = 3072 projection_dim = 512 num_hidden_layers = 12 num_attention_heads = 12 num_channels = 3 image_size = 224 patch_size = 32 hidden_act = 'quick_gelu' layer_norm_eps = 1e-05 attention_dropout = 0.0 initializer_range = 0.02 initializer_factor = 1.0 **kwargs )

Parameters

This is the configuration class to store the configuration of a CLIPVisionModel. It is used to instantiate a CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP openai/clip-vit-base-patch32 architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Example:

from transformers import CLIPVisionConfig, CLIPVisionModel

configuration = CLIPVisionConfig()

model = CLIPVisionModel(configuration)

configuration = model.config

CLIPTokenizer

class transformers.CLIPTokenizer

< source >

( vocab: typing.Union[str, dict[str, int], NoneType] = None merges: typing.Union[str, list[str], NoneType] = None unk_token: str = '<|endoftext|>' bos_token: str = '<|startoftext|>' eos_token: str = '<|endoftext|>' pad_token: str = '<|endoftext|>' **kwargs )

Parameters

Construct a CLIP tokenizer (backed by HuggingFace’s tokenizers library). Based on byte-level Byte-Pair-Encoding.

This tokenizer inherits from TokenizersBackend which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.
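As a quick usage sketch (assuming the openai/clip-vit-base-patch32 vocabulary), the tokenizer BPE-encodes the text and wraps it with the bos/eos tokens listed above:

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

encoded = tokenizer("a photo of a cat")
print(encoded["input_ids"])                    # starts with the <|startoftext|> id and ends with the <|endoftext|> id
print(tokenizer.decode(encoded["input_ids"]))  # decodes back to the text plus the special tokens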

get_special_tokens_mask

< source >

( token_ids_0: list[int] token_ids_1: Optional[list[int]] = None already_has_special_tokens: bool = False ) → A list of integers in the range [0, 1]

Parameters

Returns

A list of integers in the range [0, 1]

1 for a special token, 0 for a sequence token.

Retrieve sequence ids from a token list that has no special tokens added.

For fast tokenizers, data collators call this with already_has_special_tokens=True to build a mask over an already-formatted sequence. In that case, we compute the mask by checking membership in all_special_ids.
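For illustration, a small sketch (again assuming openai/clip-vit-base-patch32): with already_has_special_tokens=True the mask marks the bos/eos positions the tokenizer added.

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

ids = tokenizer("a photo of a cat")["input_ids"]
mask = tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
print(mask)  # 1 at the <|startoftext|>/<|endoftext|> positions, 0 for the regular tokens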

save_vocabulary

< source >

( save_directory: str filename_prefix: typing.Optional[str] = None )

CLIPTokenizerFast

class transformers.CLIPTokenizer

< source >

( vocab: typing.Union[str, dict[str, int], NoneType] = None merges: typing.Union[str, list[str], NoneType] = None unk_token: str = '<|endoftext|>' bos_token: str = '<|startoftext|>' eos_token: str = '<|endoftext|>' pad_token: str = '<|endoftext|>' **kwargs )

Parameters

Construct a CLIP tokenizer (backed by HuggingFace’s tokenizers library). Based on byte-level Byte-Pair-Encoding.

This tokenizer inherits from TokenizersBackend which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

CLIPImageProcessor

class transformers.CLIPImageProcessor

< source >

( do_resize: bool = True size: typing.Optional[dict[str, int]] = None resample: Resampling = <Resampling.BICUBIC: 3> do_center_crop: bool = True crop_size: typing.Optional[dict[str, int]] = None do_rescale: bool = True rescale_factor: typing.Union[int, float] = 0.00392156862745098 do_normalize: bool = True image_mean: typing.Union[float, list[float], NoneType] = None image_std: typing.Union[float, list[float], NoneType] = None do_convert_rgb: bool = True **kwargs )

Parameters

Constructs a CLIP image processor.

preprocess

< source >

( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']] do_resize: typing.Optional[bool] = None size: typing.Optional[dict[str, int]] = None resample: typing.Optional[PIL.Image.Resampling] = None do_center_crop: typing.Optional[bool] = None crop_size: typing.Optional[int] = None do_rescale: typing.Optional[bool] = None rescale_factor: typing.Optional[float] = None do_normalize: typing.Optional[bool] = None image_mean: typing.Union[float, list[float], NoneType] = None image_std: typing.Union[float, list[float], NoneType] = None do_convert_rgb: typing.Optional[bool] = None return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None data_format: typing.Optional[transformers.image_utils.ChannelDimension] = <ChannelDimension.FIRST: 'channels_first'> input_data_format: typing.Union[str, transformers.image_utils.ChannelDimension, NoneType] = None **kwargs )

Parameters

Preprocess an image or batch of images.
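A minimal preprocessing sketch (the checkpoint and load_image helper are the ones used elsewhere on this page); calling the image processor directly dispatches to preprocess():

from transformers import CLIPImageProcessor
from transformers.image_utils import load_image

image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
image = load_image("http://images.cocodataset.org/val2017/000000039769.jpg")

# resize -> center crop -> rescale -> normalize, returned as a BatchFeature
batch = image_processor(images=image, return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224]) with the default size/crop settings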

CLIPImageProcessorFast

class transformers.CLIPImageProcessorFast

< source >

( **kwargs: typing_extensions.Unpack[transformers.processing_utils.ImagesKwargs] )

Constructs a fast CLIP image processor.

preprocess

< source >

( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']] *args **kwargs: typing_extensions.Unpack[transformers.processing_utils.ImagesKwargs] ) → <class 'transformers.image_processing_base.BatchFeature'>

Parameters

Returns

<class 'transformers.image_processing_base.BatchFeature'>
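A usage sketch mirroring the slow processor above; the fast variant exposes the same call pattern and returns a BatchFeature:

from transformers import CLIPImageProcessorFast
from transformers.image_utils import load_image

image_processor = CLIPImageProcessorFast.from_pretrained("openai/clip-vit-base-patch32")
image = load_image("http://images.cocodataset.org/val2017/000000039769.jpg")

batch = image_processor(images=image, return_tensors="pt")
print(type(batch).__name__, batch["pixel_values"].shape)  # BatchFeature torch.Size([1, 3, 224, 224])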

CLIPProcessor

class transformers.CLIPProcessor

< source >

( image_processor = None tokenizer = None **kwargs )

Parameters

Constructs a CLIP processor which wraps a CLIP image processor and a CLIP tokenizer into a single processor.

CLIPProcessor offers all the functionalities of CLIPImageProcessor and CLIPTokenizerFast. See __call__() and decode() for more information.
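A short sketch of the combined call (mirroring the CLIPModel example below): the wrapped tokenizer handles text, the wrapped image processor handles images, and both outputs land in one BatchFeature.

from transformers import CLIPProcessor
from transformers.image_utils import load_image

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
image = load_image("http://images.cocodataset.org/val2017/000000039769.jpg")

inputs = processor(
    text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
)
print(sorted(inputs.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']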

CLIPModel

class transformers.CLIPModel

< source >

( config: CLIPConfig )

Parameters

The bare CLIP model outputting raw hidden-states without any specific head on top.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

< source >

( input_ids: typing.Optional[torch.LongTensor] = None pixel_values: typing.Optional[torch.FloatTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None return_loss: typing.Optional[bool] = None interpolate_pos_encoding: bool = False **kwargs: typing_extensions.Unpack[transformers.utils.generic.TransformersKwargs] ) → transformers.models.clip.modeling_clip.CLIPOutput or tuple(torch.FloatTensor)

Parameters

Returns

transformers.models.clip.modeling_clip.CLIPOutput or tuple(torch.FloatTensor)

A transformers.models.clip.modeling_clip.CLIPOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (CLIPConfig) and inputs.

The CLIPModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

import torch
from transformers import AutoProcessor, CLIPModel
from transformers.image_utils import load_image

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = load_image(url)

inputs = processor(
    text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
)

with torch.inference_mode():
    outputs = model(**inputs)

logits_per_image = outputs.logits_per_image  # image-text similarity scores
probs = logits_per_image.softmax(dim=1)      # label probabilities

get_text_features

< source >

( input_ids: Tensor attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None ) → text_features (torch.FloatTensor of shape (batch_size, output_dim))

Parameters

Returns

text_features (torch.FloatTensor of shape (batch_size, output_dim))

The text embeddings obtained by applying the projection layer to the pooled output of CLIPTextModel.

Examples:

import torch
from transformers import AutoTokenizer, CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

with torch.inference_mode():
    text_features = model.get_text_features(**inputs)

get_image_features

< source >

( pixel_values: FloatTensor interpolate_pos_encoding: bool = False ) → image_features (torch.FloatTensor of shape (batch_size, output_dim))

Parameters

Returns

image_features (torch.FloatTensor of shape (batch_size, output_dim))

The image embeddings obtained by applying the projection layer to the pooled output of CLIPVisionModel.

Examples:

import torch
from transformers import AutoProcessor, CLIPModel
from transformers.image_utils import load_image

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = load_image(url)

inputs = processor(images=image, return_tensors="pt")

with torch.inference_mode():
    image_features = model.get_image_features(**inputs)

CLIPTextModel

class transformers.CLIPTextModel

< source >

( config: CLIPTextConfig )

Parameters

The text model from CLIP without any head or projection on top.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

< source >

( input_ids: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None **kwargs: typing_extensions.Unpack[transformers.utils.generic.TransformersKwargs] ) → transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (CLIPConfig) and inputs.

The CLIPTextModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from transformers import AutoTokenizer, CLIPTextModel

model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled (EOS token) states

CLIPTextModelWithProjection

class transformers.CLIPTextModelWithProjection

< source >

( config: CLIPTextConfig )

Parameters

The CLIP text model with a projection layer on top (a linear layer on top of the pooled output).

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

< source >

( input_ids: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None **kwargs: typing_extensions.Unpack[transformers.utils.generic.TransformersKwargs] ) → transformers.models.clip.modeling_clip.CLIPTextModelOutput or tuple(torch.FloatTensor)

Parameters

Returns

transformers.models.clip.modeling_clip.CLIPTextModelOutput or tuple(torch.FloatTensor)

A transformers.models.clip.modeling_clip.CLIPTextModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (CLIPConfig) and inputs.

The CLIPTextModelWithProjection forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection

model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

with torch.inference_mode():
    outputs = model(**inputs)

text_embeds = outputs.text_embeds

CLIPVisionModelWithProjection

class transformers.CLIPVisionModelWithProjection

< source >

( config: CLIPVisionConfig )

Parameters

The CLIP vision model with a projection layer on top (a linear layer on top of the pooled output).

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

< source >

( pixel_values: typing.Optional[torch.FloatTensor] = None interpolate_pos_encoding: bool = False **kwargs: typing_extensions.Unpack[transformers.utils.generic.TransformersKwargs] ) → transformers.models.clip.modeling_clip.CLIPVisionModelOutput or tuple(torch.FloatTensor)

Parameters

Returns

transformers.models.clip.modeling_clip.CLIPVisionModelOutput or tuple(torch.FloatTensor)

A transformers.models.clip.modeling_clip.CLIPVisionModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (CLIPConfig) and inputs.

The CLIPVisionModelWithProjection forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

import torch
from transformers import AutoProcessor, CLIPVisionModelWithProjection
from transformers.image_utils import load_image

model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = load_image(url)

inputs = processor(images=image, return_tensors="pt")

with torch.inference_mode():
    outputs = model(**inputs)

image_embeds = outputs.image_embeds

CLIPVisionModel

class transformers.CLIPVisionModel

< source >

( config: CLIPVisionConfig )

Parameters

The vision model from CLIP without any head or projection on top.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

< source >

( pixel_values: typing.Optional[torch.FloatTensor] = None interpolate_pos_encoding: bool = False **kwargs: typing_extensions.Unpack[transformers.utils.generic.TransformersKwargs] ) → transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (CLIPConfig) and inputs.

The CLIPVisionModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from PIL import Image
import requests
from transformers import AutoProcessor, CLIPVisionModel

model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled CLS states

CLIPForImageClassification

class transformers.CLIPForImageClassification

< source >

( config: CLIPConfig )

Parameters

CLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of the patch tokens) e.g. for ImageNet.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

< source >

( pixel_values: typing.Optional[torch.Tensor] = None labels: typing.Optional[torch.Tensor] = None **kwargs: typing_extensions.Unpack[transformers.utils.generic.TransformersKwargs] ) → transformers.modeling_outputs.ImageClassifierOutput or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.ImageClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (CLIPConfig) and inputs.

The CLIPForImageClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoImageProcessor, CLIPForImageClassification
import torch
from datasets import load_dataset

dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPForImageClassification.from_pretrained("openai/clip-vit-base-patch32")

inputs = image_processor(image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
