SigLIP

SigLIP is a multimodal image-text model similar to CLIP. It uses separate image and text encoders to generate representations for both modalities.

Unlike CLIP, SigLIP employs a pairwise sigmoid loss on image-text pairs during training. This training loss eliminates the need for a global view of all pairwise similarities between images and texts within a batch. Consequently, it enables more efficient scaling to larger batch sizes while also delivering superior performance with smaller batch sizes.
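To make the contrast with CLIP concrete, here is a minimal sketch of the pairwise sigmoid loss following the pseudocode in the SigLIP paper (illustrative only, not the Transformers training implementation). It assumes img_emb and txt_emb are L2-normalized embeddings for a batch of matched image-text pairs, and t (temperature) and b (bias) are learned scalars.

import torch
import torch.nn.functional as F

def siglip_loss(img_emb, txt_emb, t, b):
    # img_emb, txt_emb: (n, d) L2-normalized embeddings; t, b: learned scalars
    n = img_emb.size(0)
    logits = img_emb @ txt_emb.t() * t + b               # (n, n) pairwise similarities
    labels = 2 * torch.eye(n, device=logits.device) - 1  # +1 on the diagonal (matching pairs), -1 elsewhere
    # every (image, text) pair is an independent binary classification,
    # so no batch-wide softmax normalization is needed
    return -F.logsigmoid(labels * logits).sum() / n

Because each pair is scored independently, the loss does not require gathering all pairwise similarities across devices, which is what makes very large batch sizes cheaper to train with.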

You can find all the original SigLIP checkpoints under the SigLIP collection.

Click on the SigLIP models in the right sidebar for more examples of how to apply SigLIP to different image and text tasks.

The example below demonstrates how to generate similarity scores between texts and image(s) with Pipeline or the AutoModel class.

import torch
from transformers import pipeline

image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
candidate_labels = ["a Pallas cat", "a lion", "a Siberian tiger"]

pipeline = pipeline(task="zero-shot-image-classification", model="google/siglip-base-patch16-224", device=0, torch_dtype=torch.bfloat16)
pipeline(image, candidate_labels=candidate_labels)
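The same zero-shot classification can be run with the AutoModel class directly. The snippet below is a minimal sketch that mirrors the quantized example further down, just without quantization.

import torch
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModel

model = AutoModel.from_pretrained("google/siglip-base-patch16-224", device_map="auto")
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
candidate_labels = ["a Pallas cat", "a lion", "a Siberian tiger"]
texts = [f"This is a photo of {label}." for label in candidate_labels]

# pass padding="max_length" because the model was trained with it
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs)

probs = torch.sigmoid(outputs.logits_per_image)  # each label is scored independently
print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")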

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.

The example below uses bitsandbytes to only quantize the weights to int4.

import torch
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModel, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModel.from_pretrained("google/siglip-base-patch16-224", quantization_config=bnb_config, device_map="auto", attn_implementation="sdpa")
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
candidate_labels = ["a Pallas cat", "a lion", "a Siberian tiger"]
texts = [f'This is a photo of {label}.' for label in candidate_labels]
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt").to("cuda")

with torch.no_grad():
    outputs = model(**inputs)

logits_per_image = outputs.logits_per_image
probs = torch.sigmoid(logits_per_image)
print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")


SiglipConfig

class transformers.SiglipConfig


( text_config = None vision_config = None **kwargs )

Parameters

SiglipConfig is the configuration class to store the configuration of a SiglipModel. It is used to instantiate a SigLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the SigLIP google/siglip-base-patch16-224 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

from transformers import SiglipConfig, SiglipModel

# Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
configuration = SiglipConfig()

# Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
model = SiglipModel(configuration)

# Accessing the model configuration
configuration = model.config

# A SiglipConfig can also be initialized from a SiglipTextConfig and a SiglipVisionConfig
from transformers import SiglipTextConfig, SiglipVisionConfig

# Initializing SigLIP text and vision configurations
config_text = SiglipTextConfig()
config_vision = SiglipVisionConfig()

config = SiglipConfig.from_text_vision_configs(config_text, config_vision)

from_text_vision_configs


( text_config: SiglipTextConfig vision_config: SiglipVisionConfig **kwargs ) → SiglipConfig

Returns

SiglipConfig

An instance of a configuration object.

Instantiate a SiglipConfig (or a derived class) from a SigLIP text model configuration and a SigLIP vision model configuration.

SiglipTextConfig

class transformers.SiglipTextConfig


( vocab_size = 32000 hidden_size = 768 intermediate_size = 3072 num_hidden_layers = 12 num_attention_heads = 12 max_position_embeddings = 64 hidden_act = 'gelu_pytorch_tanh' layer_norm_eps = 1e-06 attention_dropout = 0.0 pad_token_id = 1 bos_token_id = 49406 eos_token_id = 49407 projection_size = None **kwargs )

Parameters

This is the configuration class to store the configuration of a SiglipTextModel. It is used to instantiate a SigLIP text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the text encoder of the SigLIP google/siglip-base-patch16-224 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

from transformers import SiglipTextConfig, SiglipTextModel

# Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
configuration = SiglipTextConfig()

# Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
model = SiglipTextModel(configuration)

# Accessing the model configuration
configuration = model.config

SiglipVisionConfig

class transformers.SiglipVisionConfig


( hidden_size = 768 intermediate_size = 3072 num_hidden_layers = 12 num_attention_heads = 12 num_channels = 3 image_size = 224 patch_size = 16 hidden_act = 'gelu_pytorch_tanh' layer_norm_eps = 1e-06 attention_dropout = 0.0 **kwargs )

Parameters

This is the configuration class to store the configuration of a SiglipVisionModel. It is used to instantiate a SigLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the vision encoder of the SigLIP google/siglip-base-patch16-224 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

from transformers import SiglipVisionConfig, SiglipVisionModel

# Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
configuration = SiglipVisionConfig()

# Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
model = SiglipVisionModel(configuration)

# Accessing the model configuration
configuration = model.config

SiglipTokenizer

class transformers.SiglipTokenizer


( vocab_file eos_token = '</s>' unk_token = '<unk>' pad_token = '</s>' additional_special_tokens = None sp_model_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None model_max_length = 64 do_lower_case = True **kwargs )

Parameters

Construct a Siglip tokenizer. Based on SentencePiece.

This tokenizer inherits from PreTrainedTokenizer which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.
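As a quick illustrative sketch of typical usage (it requires the sentencepiece package, and SigLIP checkpoints use a 64-token maximum length):

from transformers import SiglipTokenizer

tokenizer = SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224")

# pad to the model's 64-token maximum length, since the model was trained that way
inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
print(inputs.input_ids.shape)  # expected: torch.Size([2, 64])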

build_inputs_with_special_tokens


( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None ) → List[int]

Parameters

Returns

List[int]

List of input IDs with the appropriate special tokens.

Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and adding special tokens. A sequence has the following format:

- single sequence: X </s>
- pair of sequences: A </s> B </s>
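A small illustrative sketch of the method (the actual token IDs depend on the checkpoint's vocabulary):

from transformers import SiglipTokenizer

tokenizer = SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224")

ids_a = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("a photo of a cat"))
ids_b = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("a photo of a dog"))

single = tokenizer.build_inputs_with_special_tokens(ids_a)       # X </s>
pair = tokenizer.build_inputs_with_special_tokens(ids_a, ids_b)  # A </s> B </s>
print(len(single), len(pair))  # each sequence gains one eos token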

get_special_tokens_mask


( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None already_has_special_tokens: bool = False ) → List[int]

Parameters

Returns

List[int]

A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.

Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer prepare_for_model method.

create_token_type_ids_from_sequences


( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None ) → List[int]

Parameters

Returns

List[int]

List of zeros.

Create a mask from the two sequences passed to be used in a sequence-pair classification task. SigLIP does not make use of token type ids, therefore a list of zeros is returned.

save_vocabulary


( save_directory: str filename_prefix: typing.Optional[str] = None )

SiglipImageProcessor

class transformers.SiglipImageProcessor


( do_resize: bool = True size: typing.Optional[typing.Dict[str, int]] = None resample: Resampling = <Resampling.BICUBIC: 3> do_rescale: bool = True rescale_factor: typing.Union[int, float] = 0.00392156862745098 do_normalize: bool = True image_mean: typing.Union[float, typing.List[float], NoneType] = None image_std: typing.Union[float, typing.List[float], NoneType] = None do_convert_rgb: typing.Optional[bool] = None **kwargs )

Parameters

Constructs a SigLIP image processor.

preprocess


( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']] do_resize: typing.Optional[bool] = None size: typing.Optional[typing.Dict[str, int]] = None resample: Resampling = None do_rescale: typing.Optional[bool] = None rescale_factor: typing.Optional[float] = None do_normalize: typing.Optional[bool] = None image_mean: typing.Union[float, typing.List[float], NoneType] = None image_std: typing.Union[float, typing.List[float], NoneType] = None return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None data_format: typing.Optional[transformers.image_utils.ChannelDimension] = <ChannelDimension.FIRST: 'channels_first'> input_data_format: typing.Union[str, transformers.image_utils.ChannelDimension, NoneType] = None do_convert_rgb: typing.Optional[bool] = None )

Parameters

Preprocess an image or batch of images.
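For illustration, a minimal sketch of calling the image processor (calling the instance dispatches to preprocess):

import requests
from PIL import Image
from transformers import SiglipImageProcessor

image_processor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# resizes to 224x224, rescales to [0, 1], and normalizes with the checkpoint's mean/std
inputs = image_processor(images=image, return_tensors="pt")
print(inputs.pixel_values.shape)  # expected: torch.Size([1, 3, 224, 224])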

SiglipImageProcessorFast

class transformers.SiglipImageProcessorFast


( **kwargs: typing_extensions.Unpack[transformers.image_processing_utils_fast.DefaultFastImageProcessorKwargs] )

Parameters

Constructs a fast SigLIP image processor.

preprocess


( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']] **kwargs: typing_extensions.Unpack[transformers.image_processing_utils_fast.DefaultFastImageProcessorKwargs] )

Parameters

Preprocess an image or batch of images.

SiglipProcessor

class transformers.SiglipProcessor


( image_processor tokenizer )

Parameters

Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.

SiglipProcessor offers all the functionalities of SiglipImageProcessor and SiglipTokenizer. See the __call__() and decode() methods for more information.

batch_decode

This method forwards all its arguments to SiglipTokenizer’s batch_decode(). Please refer to the docstring of this method for more information.

decode

This method forwards all its arguments to SiglipTokenizer’s decode(). Please refer to the docstring of this method for more information.
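A minimal sketch of how the processor combines both modalities:

import requests
from PIL import Image
from transformers import SiglipProcessor

processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# text goes through SiglipTokenizer, images through SiglipImageProcessor
inputs = processor(text=["a photo of 2 cats"], images=image, padding="max_length", return_tensors="pt")
print(inputs.keys())  # typically input_ids and pixel_values

# decode()/batch_decode() are forwarded to the tokenizer
print(processor.batch_decode(inputs.input_ids, skip_special_tokens=True))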

SiglipModel

class transformers.SiglipModel


( config: SiglipConfig )

Parameters

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_ids: typing.Optional[torch.LongTensor] = None pixel_values: typing.Optional[torch.FloatTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None return_loss: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None interpolate_pos_encoding: bool = False ) → transformers.models.siglip.modeling_siglip.SiglipOutput or tuple(torch.FloatTensor)

Parameters

Returns

transformers.models.siglip.modeling_siglip.SiglipOutput or tuple(torch.FloatTensor)

A transformers.models.siglip.modeling_siglip.SiglipOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SiglipConfig) and inputs.

The SiglipModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch

model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

texts = ["a photo of 2 cats", "a photo of 2 dogs"]
# important: pass padding="max_length" since the model was trained with it
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

logits_per_image = outputs.logits_per_image
probs = torch.sigmoid(logits_per_image)  # probabilities are the sigmoid of the logits
print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
31.9% that image 0 is 'a photo of 2 cats'

get_text_features


( input_ids: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None ) → text_features (torch.FloatTensor of shape (batch_size, output_dim))

Parameters

Returns

text_features (torch.FloatTensor of shape (batch_size, output_dim))

The text embeddings obtained by applying the projection layer to the pooled output of SiglipTextModel.

The SiglipModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from transformers import AutoTokenizer, AutoModel
import torch

model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

# important: pass padding="max_length" since the model was trained with it
inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
with torch.no_grad():
    text_features = model.get_text_features(**inputs)

get_image_features


( pixel_values: typing.Optional[torch.FloatTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None interpolate_pos_encoding: bool = False ) → image_features (torch.FloatTensor of shape (batch_size, output_dim))

Parameters

Returns

image_features (torch.FloatTensor of shape (batch_size, output_dim))

The image embeddings obtained by applying the projection layer to the pooled output of SiglipVisionModel.

The SiglipModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch

model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    image_features = model.get_image_features(**inputs)

SiglipTextModel

class transformers.SiglipTextModel


( config: SiglipTextConfig )

Parameters

The text model from SigLIP without any head or projection on top. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_ids: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None ) → transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

Parameters

Returns

A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SiglipTextConfig) and inputs.

The SiglipTextModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from transformers import AutoTokenizer, SiglipTextModel

model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

# important: pass padding="max_length" since the model was trained with it
inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled (EOS token) states

SiglipVisionModel

class transformers.SiglipVisionModel


( config: SiglipVisionConfig )

Parameters

The vision model from SigLIP without any head or projection on top. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( pixel_values output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None interpolate_pos_encoding: bool = False ) → transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

Parameters

Returns

A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SiglipVisionConfig) and inputs.

The SiglipVisionModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from PIL import Image
import requests
from transformers import AutoProcessor, SiglipVisionModel

model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled features

SiglipForImageClassification

class transformers.SiglipForImageClassification


( config: SiglipConfig )

Parameters

SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of the patch tokens) e.g. for ImageNet.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( pixel_values: typing.Optional[torch.Tensor] = None labels: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None interpolate_pos_encoding: bool = False ) → transformers.modeling_outputs.ImageClassifierOutput or tuple(torch.FloatTensor)

Parameters

Returns

A transformers.modeling_outputs.ImageClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SiglipConfig) and inputs.

The SiglipForImageClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from transformers import AutoImageProcessor, SiglipForImageClassification
import torch
from PIL import Image
import requests

torch.manual_seed(3)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# note: this checkpoint does not ship a classification head, so the head
# is randomly initialized and the predicted labels are not meaningful
image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")

inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits

# model predicts one of the configured classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: LABEL_1
