X-CLIP

Overview

The X-CLIP model was proposed in Expanding Language-Image Pretrained Models for General Video Recognition by Bolin Ni, Houwen Peng, Minghao Chen, Songyang Zhang, Gaofeng Meng, Jianlong Fu, Shiming Xiang, Haibin Ling. X-CLIP is a minimal extension of CLIP for video. The model consists of a text encoder, a cross-frame vision encoder, a multi-frame integration Transformer, and a video-specific prompt generator.

The abstract from the paper is the following:

Contrastive language-image pretraining has shown great success in learning visual-textual joint representation from web-scale data, demonstrating remarkable β€œzero-shot” generalization ability for various image tasks. However, how to effectively expand such new language-image pretraining methods to video domains is still an open problem. In this work, we present a simple yet effective approach that adapts the pretrained language-image models to video recognition directly, instead of pretraining a new model from scratch. More concretely, to capture the long-range dependencies of frames along the temporal dimension, we propose a cross-frame attention mechanism that explicitly exchanges information across frames. Such module is lightweight and can be plugged into pretrained language-image models seamlessly. Moreover, we propose a video-specific prompting scheme, which leverages video content information for generating discriminative textual prompts. Extensive experiments demonstrate that our approach is effective and can be generalized to different video recognition scenarios. In particular, under fully-supervised settings, our approach achieves a top-1 accuracy of 87.1% on Kinectics-400, while using 12 times fewer FLOPs compared with Swin-L and ViViT-H. In zero-shot experiments, our approach surpasses the current state-of-the-art methods by +7.6% and +14.9% in terms of top-1 accuracy under two popular protocols. In few-shot scenarios, our approach outperforms previous best methods by +32.1% and +23.1% when the labeled data is extremely limited.


X-CLIP architecture. Taken from the original paper.

This model was contributed by nielsr. The original code can be found here.

Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with X-CLIP.

If you’re interested in submitting a resource to be included here, please feel free to open a Pull Request and we’ll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

XCLIPProcessor

class transformers.XCLIPProcessor


( image_processor = None tokenizer = None **kwargs )


Constructs an X-CLIP processor which wraps a VideoMAE image processor and a CLIP tokenizer into a single processor.

XCLIPProcessor offers all the functionalities of VideoMAEImageProcessor and CLIPTokenizerFast. See the __call__() and decode() for more information.

batch_decode

This method forwards all its arguments to CLIPTokenizerFast’s batch_decode(). Please refer to the docstring of this method for more information.

decode

This method forwards all its arguments to CLIPTokenizerFast’s decode(). Please refer to the docstring of this method for more information.
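For orientation, here is a minimal sketch of how the processor is typically called, pairing text prompts with a list of video frames. The random frames below are placeholders for a real decoded clip (see the full video-loading examples further down), and the exact keys of the returned batch may vary:

```python
import numpy as np
from transformers import XCLIPProcessor

processor = XCLIPProcessor.from_pretrained("microsoft/xclip-base-patch32")

# 8 dummy RGB frames of shape (height, width, channels), standing in for a real decoded video clip
video = [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) for _ in range(8)]

inputs = processor(
    text=["playing sports", "eating spaghetti"],
    videos=video,
    return_tensors="pt",
    padding=True,
)

# the batch typically contains input_ids / attention_mask (from the tokenizer)
# and pixel_values (from the image processor)
print(inputs.keys())
```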

XCLIPConfig

class transformers.XCLIPConfig


( text_config = None vision_config = None projection_dim = 512 prompt_layers = 2 prompt_alpha = 0.1 prompt_hidden_act = 'quick_gelu' prompt_num_attention_heads = 8 prompt_attention_dropout = 0.0 prompt_projection_dropout = 0.0 logit_scale_init_value = 2.6592 **kwargs )


XCLIPConfig is the configuration class to store the configuration of an XCLIPModel. It is used to instantiate an X-CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the X-CLIP microsoft/xclip-base-patch32 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
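A minimal sketch of instantiating a default configuration and a randomly initialized model from it, mirroring the sub-config examples below:

```python
from transformers import XCLIPConfig, XCLIPModel

# Initializing an XCLIPConfig with default (microsoft/xclip-base-patch32 style) values
configuration = XCLIPConfig()

# Initializing an XCLIPModel (with random weights) from that configuration
model = XCLIPModel(configuration)

# Accessing the model configuration
configuration = model.config
```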

from_text_vision_configs


( text_config: XCLIPTextConfig vision_config: XCLIPVisionConfig **kwargs ) β†’ XCLIPConfig

Returns

XCLIPConfig: an instance of a configuration object.

Instantiate an XCLIPConfig (or a derived class) from an X-CLIP text model configuration and an X-CLIP vision model configuration.
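A minimal sketch of composing a configuration this way; the sub-configs use default values here, with num_frames and projection_dim shown purely as illustrative overrides:

```python
from transformers import XCLIPConfig, XCLIPTextConfig, XCLIPVisionConfig, XCLIPModel

# Build the two sub-configurations (defaults, with an illustrative override)
text_config = XCLIPTextConfig()
vision_config = XCLIPVisionConfig(num_frames=8)

# Combine them into a full X-CLIP configuration and instantiate a randomly initialized model
config = XCLIPConfig.from_text_vision_configs(text_config, vision_config, projection_dim=512)
model = XCLIPModel(config)
```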

XCLIPTextConfig

class transformers.XCLIPTextConfig


( vocab_size = 49408 hidden_size = 512 intermediate_size = 2048 num_hidden_layers = 12 num_attention_heads = 8 max_position_embeddings = 77 hidden_act = 'quick_gelu' layer_norm_eps = 1e-05 attention_dropout = 0.0 initializer_range = 0.02 initializer_factor = 1.0 pad_token_id = 1 bos_token_id = 0 eos_token_id = 2 **kwargs )


This is the configuration class to store the configuration of an XCLIPTextModel. It is used to instantiate an X-CLIP text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the X-CLIP microsoft/xclip-base-patch32 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

```python
from transformers import XCLIPTextModel, XCLIPTextConfig

# Initializing an XCLIPTextConfig with default (microsoft/xclip-base-patch32 style) values
configuration = XCLIPTextConfig()

# Initializing an XCLIPTextModel (with random weights) from that configuration
model = XCLIPTextModel(configuration)

# Accessing the model configuration
configuration = model.config
```

XCLIPVisionConfig

class transformers.XCLIPVisionConfig


( hidden_size = 768 intermediate_size = 3072 num_hidden_layers = 12 num_attention_heads = 12 mit_hidden_size = 512 mit_intermediate_size = 2048 mit_num_hidden_layers = 1 mit_num_attention_heads = 8 num_channels = 3 image_size = 224 patch_size = 32 num_frames = 8 hidden_act = 'quick_gelu' layer_norm_eps = 1e-05 attention_dropout = 0.0 initializer_range = 0.02 initializer_factor = 1.0 drop_path_rate = 0.0 **kwargs )


This is the configuration class to store the configuration of an XCLIPVisionModel. It is used to instantiate an X-CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the X-CLIP microsoft/xclip-base-patch32 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

```python
from transformers import XCLIPVisionModel, XCLIPVisionConfig

# Initializing an XCLIPVisionConfig with default (microsoft/xclip-base-patch32 style) values
configuration = XCLIPVisionConfig()

# Initializing an XCLIPVisionModel (with random weights) from that configuration
model = XCLIPVisionModel(configuration)

# Accessing the model configuration
configuration = model.config
```

XCLIPModel

class transformers.XCLIPModel


( config: XCLIPConfig )


This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

forward


( input_ids: Optional = None pixel_values: Optional = None attention_mask: Optional = None position_ids: Optional = None return_loss: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) β†’ transformers.models.x_clip.modeling_x_clip.XCLIPOutput or tuple(torch.FloatTensor)


Returns

transformers.models.x_clip.modeling_x_clip.XCLIPOutput or tuple(torch.FloatTensor)

A transformers.models.x_clip.modeling_x_clip.XCLIPOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (<class 'transformers.models.x_clip.configuration_x_clip.XCLIPConfig'>) and inputs.

The XCLIPModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
import av
import torch
import numpy as np

from transformers import AutoProcessor, AutoModel
from huggingface_hub import hf_hub_download

np.random.seed(0)


def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (av.container.input.InputContainer): PyAV container.
        indices (List[int]): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    '''
    Sample a given number of frame indices from the video.
    Args:
        clip_len (int): Total number of frames to sample.
        frame_sample_rate (int): Sample every n-th frame.
        seg_len (int): Maximum allowed index of sample's last frame.
    Returns:
        indices (List[int]): List of sampled frame indices
    '''
    converted_len = int(clip_len * frame_sample_rate)
    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices


# download a sample video (eating spaghetti)
file_path = hf_hub_download(
    repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
)
container = av.open(file_path)

# sample 8 frames
indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
video = read_video_pyav(container, indices)

processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")

inputs = processor(
    text=["playing sports", "eating spaghetti", "go shopping"],
    videos=list(video),
    return_tensors="pt",
    padding=True,
)

# forward pass
with torch.no_grad():
    outputs = model(**inputs)

logits_per_video = outputs.logits_per_video  # video-text similarity scores
probs = logits_per_video.softmax(dim=1)  # take the softmax to get label probabilities
print(probs)
# tensor([[1.9496e-04, 9.9960e-01, 2.0825e-04]])
```

get_text_features


( input_ids: Optional = None attention_mask: Optional = None position_ids: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) β†’ text_features (torch.FloatTensor of shape (batch_size, output_dim))

Returns

text_features (torch.FloatTensor of shape (batch_size, output_dim))

The text embeddings obtained by applying the projection layer to the pooled output of XCLIPTextModel.

The XCLIPModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("microsoft/xclip-base-patch32")
model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
text_features = model.get_text_features(**inputs)
```

get_video_features


( pixel_values: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) β†’ video_features (torch.FloatTensor of shape (batch_size, output_dim))

Returns

video_features (torch.FloatTensor of shape (batch_size, output_dim))

The video embeddings obtained by applying the projection layer to the pooled output of XCLIPVisionModel and XCLIPMultiframeIntegrationTransformer.

The XCLIPModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
import av
import torch
import numpy as np

from transformers import AutoProcessor, AutoModel
from huggingface_hub import hf_hub_download

np.random.seed(0)


def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (av.container.input.InputContainer): PyAV container.
        indices (List[int]): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    '''
    Sample a given number of frame indices from the video.
    Args:
        clip_len (int): Total number of frames to sample.
        frame_sample_rate (int): Sample every n-th frame.
        seg_len (int): Maximum allowed index of sample's last frame.
    Returns:
        indices (List[int]): List of sampled frame indices
    '''
    converted_len = int(clip_len * frame_sample_rate)
    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices


# download a sample video (eating spaghetti)
file_path = hf_hub_download(
    repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
)
container = av.open(file_path)

# sample 8 frames
indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
video = read_video_pyav(container, indices)

processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")

inputs = processor(videos=list(video), return_tensors="pt")

video_features = model.get_video_features(**inputs)
```
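The features returned by get_text_features() and get_video_features() live in the shared projection space, so they can be compared directly. A minimal sketch follows; the random tensors stand in for real outputs of the two methods, and note that the forward pass additionally applies a learned logit scale before the softmax:

```python
import torch

# stand-ins for real outputs of get_video_features / get_text_features
video_features = torch.randn(2, 512)  # (num_videos, output_dim)
text_features = torch.randn(3, 512)   # (num_texts, output_dim)

# normalize, then compute cosine similarity between every video and every text prompt
video_features = video_features / video_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

similarity = video_features @ text_features.T  # (num_videos, num_texts)
probs = similarity.softmax(dim=-1)             # pseudo-probabilities per video
```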

XCLIPTextModel

class transformers.XCLIPTextModel


( config: XCLIPTextConfig )

forward


( input_ids: Optional = None attention_mask: Optional = None position_ids: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) β†’ transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

Returns

transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (<class 'transformers.models.x_clip.configuration_x_clip.XCLIPTextConfig'>) and inputs.

The XCLIPTextModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
from transformers import AutoTokenizer, XCLIPTextModel

model = XCLIPTextModel.from_pretrained("microsoft/xclip-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("microsoft/xclip-base-patch32")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled (EOS token) states
```

XCLIPVisionModel

class transformers.XCLIPVisionModel


( config: XCLIPVisionConfig )

forward


( pixel_values: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) β†’ transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

Returns

transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (<class 'transformers.models.x_clip.configuration_x_clip.XCLIPVisionConfig'>) and inputs.

The XCLIPVisionModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
import av
import torch
import numpy as np

from transformers import AutoProcessor, XCLIPVisionModel
from huggingface_hub import hf_hub_download

np.random.seed(0)


def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (av.container.input.InputContainer): PyAV container.
        indices (List[int]): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    '''
    Sample a given number of frame indices from the video.
    Args:
        clip_len (int): Total number of frames to sample.
        frame_sample_rate (int): Sample every n-th frame.
        seg_len (int): Maximum allowed index of sample's last frame.
    Returns:
        indices (List[int]): List of sampled frame indices
    '''
    converted_len = int(clip_len * frame_sample_rate)
    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices


# download a sample video (eating spaghetti)
file_path = hf_hub_download(
    repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
)
container = av.open(file_path)

# sample 8 frames
indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
video = read_video_pyav(container, indices)

processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
model = XCLIPVisionModel.from_pretrained("microsoft/xclip-base-patch32")

pixel_values = processor(videos=list(video), return_tensors="pt").pixel_values

batch_size, num_frames, num_channels, height, width = pixel_values.shape

# the vision encoder processes each frame as a separate image, so flatten frames into the batch dimension
pixel_values = pixel_values.reshape(-1, num_channels, height, width)

outputs = model(pixel_values)
last_hidden_state = outputs.last_hidden_state
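Because the frames were flattened into the batch dimension, the per-frame hidden states can be regrouped by video afterwards; a short sketch continuing from the variables above:

```python
# (batch_size * num_frames, sequence_length, hidden_size) -> (batch_size, num_frames, sequence_length, hidden_size)
per_frame = last_hidden_state.reshape(batch_size, num_frames, *last_hidden_state.shape[1:])
```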
