GIT

PyTorch

Overview

The GIT model was proposed in GIT: A Generative Image-to-text Transformer for Vision and Language by Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang. GIT is a decoder-only Transformer that leverages CLIP’s vision encoder to condition the model on vision inputs in addition to text. The model obtains state-of-the-art results on image captioning and visual question answering benchmarks.

The abstract from the paper is the following:

In this paper, we design and train a Generative Image-to-text Transformer, GIT, to unify vision-language tasks such as image/video captioning and question answering. While generative models provide a consistent network architecture between pre-training and fine-tuning, existing work typically contains complex structures (uni/multi-modal encoder/decoder) and depends on external modules such as object detectors/taggers and optical character recognition (OCR). In GIT, we simplify the architecture as one image encoder and one text decoder under a single language modeling task. We also scale up the pre-training data and the model size to boost the model performance. Without bells and whistles, our GIT establishes new state of the arts on 12 challenging benchmarks with a large margin. For instance, our model surpasses the human performance for the first time on TextCaps (138.2 vs. 125.5 in CIDEr). Furthermore, we present a new scheme of generation-based image classification and scene text recognition, achieving decent performance on standard benchmarks.

GIT architecture. Taken from the original paper.

This model was contributed by nielsr. The original code can be found here.

Usage tips
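GIT checkpoints can be tried out quickly through the high-level image-to-text pipeline. The snippet below is a minimal sketch assuming the microsoft/git-base-coco checkpoint (the same checkpoint used in the captioning example further down this page); any fine-tuned GIT captioning checkpoint should work the same way.

from transformers import pipeline

# Image captioning with a GIT checkpoint through the image-to-text pipeline.
captioner = pipeline("image-to-text", model="microsoft/git-base-coco")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
print(captioner(url))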

Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GIT.

If you’re interested in submitting a resource to be included here, please feel free to open a Pull Request and we will review it. The resource should ideally demonstrate something new instead of duplicating an existing resource.

GitVisionConfig

class transformers.GitVisionConfig

< source >

( hidden_size = 768 intermediate_size = 3072 num_hidden_layers = 12 num_attention_heads = 12 num_channels = 3 image_size = 224 patch_size = 16 hidden_act = 'quick_gelu' layer_norm_eps = 1e-05 attention_dropout = 0.0 initializer_range = 0.02 **kwargs )

Parameters

This is the configuration class to store the configuration of a GitVisionModel. It is used to instantiate a GIT vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the vision encoder of the GIT microsoft/git-base architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

from transformers import GitVisionConfig, GitVisionModel

configuration = GitVisionConfig()

model = GitVisionModel(configuration)

configuration = model.config

GitVisionModel

class transformers.GitVisionModel

< source >

( config: GitVisionConfig )

Parameters

The vision model from CLIP, used in GIT, without any head or projection on top.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

< source >

( pixel_values: typing.Optional[torch.FloatTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None interpolate_pos_encoding: bool = False return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.BaseModelOutput or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.BaseModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (GitConfig) and inputs.

The GitVisionModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from PIL import Image
import requests
from transformers import AutoProcessor, GitVisionModel

processor = AutoProcessor.from_pretrained("microsoft/git-base") model = GitVisionModel.from_pretrained("microsoft/git-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state

GitConfig

class transformers.GitConfig

< source >

( vision_config = None vocab_size = 30522 hidden_size = 768 num_hidden_layers = 6 num_attention_heads = 12 intermediate_size = 3072 hidden_act = 'gelu' hidden_dropout_prob = 0.1 attention_probs_dropout_prob = 0.1 max_position_embeddings = 1024 initializer_range = 0.02 layer_norm_eps = 1e-12 pad_token_id = 0 position_embedding_type = 'absolute' use_cache = True tie_word_embeddings = False bos_token_id = 101 eos_token_id = 102 num_image_with_embedding = None **kwargs )

Parameters

This is the configuration class to store the configuration of a GitModel. It is used to instantiate a GIT model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the GIT microsoft/git-base architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Examples:

from transformers import GitConfig, GitModel

configuration = GitConfig()

model = GitModel(configuration)

configuration = model.config
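Since GitConfig composes a text decoder with a GitVisionConfig, the vision encoder can be customized at the same time. The sketch below assumes vision_config accepts a plain dict of GitVisionConfig arguments; the values shown are simply the documented defaults and are only illustrative.

from transformers import GitConfig, GitModel

# Build a GIT configuration with an explicitly specified vision encoder and text decoder.
# The vision_config dict is used to build the GitVisionConfig; unspecified values keep their defaults.
configuration = GitConfig(
    vision_config={"image_size": 224, "patch_size": 16},
    num_hidden_layers=6,
)

model = GitModel(configuration)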

GitProcessor

class transformers.GitProcessor

< source >

( image_processor tokenizer )

Parameters

Constructs a GIT processor which wraps a CLIP image processor and a BERT tokenizer into a single processor.

GitProcessor offers all the functionalities of CLIPImageProcessor and BertTokenizerFast. See the __call__() and decode() methods for more information.

__call__

< source >

( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor'], NoneType] = None text: typing.Union[str, typing.List[str], typing.List[typing.List[str]], NoneType] = None audio = None videos = None **kwargs: typing_extensions.Unpack[transformers.models.git.processing_git.GitProcessorKwargs] ) → BatchFeature

Parameters

A BatchFeature with the following fields:

Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the text and kwargs arguments to BertTokenizerFast’s __call__() if text is not None to encode the text. To prepare the image(s), this method forwards the images and kwargs arguments to CLIPImageProcessor’s __call__() if images is not None. Please refer to the docstrings of the two methods above for more information.
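A short sketch of what the processor returns for a paired image and text input, reusing the microsoft/git-base checkpoint and COCO image from the other examples on this page. The exact set of returned fields depends on which inputs are passed; with both text and images it includes the tokenizer outputs and pixel_values.

from PIL import Image
import requests
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/git-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Passing both text and images returns token ids from the tokenizer and pixel values from the image processor.
inputs = processor(images=image, text="this is an image of two cats", return_tensors="pt")
print(list(inputs.keys()))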

GitModel

class transformers.GitModel

< source >

( config )

Parameters

The bare GIT Model transformer, consisting of a CLIP image encoder and a text decoder, outputting raw hidden-states without any specific head on top.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

< source >

( input_ids: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None pixel_values: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None past_key_values: typing.Union[typing.List[torch.FloatTensor], transformers.cache_utils.Cache, NoneType] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None interpolate_pos_encoding: bool = False return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (GitConfig) and inputs.

The GitModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

from transformers import AutoProcessor, AutoModel
import requests
from PIL import Image

processor = AutoProcessor.from_pretrained("microsoft/git-base") model = AutoModel.from_pretrained("microsoft/git-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw)

text = "this is an image of two cats"

inputs = processor(images=image, text=text, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state

GitForCausalLM

class transformers.GitForCausalLM

< source >

( config )

Parameters

GIT Model with a language modeling head on top for autoregressive language modeling.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

< source >

( input_ids: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None pixel_values: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None labels: typing.Optional[torch.Tensor] = None past_key_values: typing.Union[transformers.cache_utils.Cache, typing.List[torch.Tensor], NoneType] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None interpolate_pos_encoding: bool = False return_dict: typing.Optional[bool] = None **kwargs ) → transformers.modeling_outputs.CausalLMOutputWithPast or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.CausalLMOutputWithPast or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (GitConfig) and inputs.

The GitForCausalLM forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

Image captioning example:

from transformers import AutoProcessor, AutoModelForCausalLM
import requests
from PIL import Image

processor = AutoProcessor.from_pretrained("microsoft/git-base-coco") model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw)

pixel_values = processor(images=image, return_tensors="pt").pixel_values

generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)
two cats sleeping on a pink blanket next to remotes.

Visual question answering (VQA) example:

from transformers import AutoProcessor, AutoModelForCausalLM
from huggingface_hub import hf_hub_download
from PIL import Image
import torch

processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa") model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")

file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset") image = Image.open(file_path).convert("RGB")

pixel_values = processor(images=image, return_tensors="pt").pixel_values

question = "what does the front of the bus say at the top?"

input_ids = processor(text=question, add_special_tokens=False).input_ids
input_ids = [processor.tokenizer.cls_token_id] + input_ids
input_ids = torch.tensor(input_ids).unsqueeze(0)

generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
['what does the front of the bus say at the top? special']

Video captioning example:

import av
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, AutoModelForCausalLM

processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex") model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")

np.random.seed(45)

def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (av.container.input.InputContainer): PyAV container.
        indices (List[int]): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])

def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    '''
    Sample a given number of frame indices from the video.
    Args:
        clip_len (int): Total number of frames to sample.
        frame_sample_rate (int): Sample every n-th frame.
        seg_len (int): Maximum allowed index of sample's last frame.
    Returns:
        indices (List[int]): List of sampled frame indices
    '''
    converted_len = int(clip_len * frame_sample_rate)
    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices

file_path = hf_hub_download(
    repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
)
container = av.open(file_path)

num_frames = model.config.num_image_with_embedding
indices = sample_frame_indices(
    clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
)
frames = read_video_pyav(container, indices)

pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values

generated_ids = model.generate(pixel_values=pixel_values, max_length=50)

print("Generated caption:", processor.batch_decode(generated_ids, skip_special_tokens=True)) Generated caption: ['a woman is sitting at a table and she is talking about the food she is holding.']
