Pix2Struct

PyTorch

Overview

The Pix2Struct model was proposed in Pix2Struct: Screenshot Parsing as Pretraining for Visual Language Understanding by Kenton Lee, Mandar Joshi, Iulia Turc, Hexiang Hu, Fangyu Liu, Julian Eisenschlos, Urvashi Khandelwal, Peter Shaw, Ming-Wei Chang, Kristina Toutanova.

The abstract from the paper is the following:

Visually-situated language is ubiquitous — sources range from textbooks with diagrams to web pages with images and tables, to mobile apps with buttons and forms. Perhaps due to this diversity, previous work has typically relied on domain-specific recipes with limited sharing of the underlying data, model architectures, and objectives. We present Pix2Struct, a pretrained image-to-text model for purely visual language understanding, which can be finetuned on tasks containing visually-situated language. Pix2Struct is pretrained by learning to parse masked screenshots of web pages into simplified HTML. The web, with its richness of visual elements cleanly reflected in the HTML structure, provides a large source of pretraining data well suited to the diversity of downstream tasks. Intuitively, this objective subsumes common pretraining signals such as OCR, language modeling, image captioning. In addition to the novel pretraining strategy, we introduce a variable-resolution input representation and a more flexible integration of language and vision inputs, where language prompts such as questions are rendered directly on top of the input image. For the first time, we show that a single pretrained model can achieve state-of-the-art results in six out of nine tasks across four domains: documents, illustrations, user interfaces, and natural images.

Tips:

Pix2Struct has been fine-tuned on a variety of tasks and datasets, ranging from image captioning and visual question answering (VQA) over different inputs (books, charts, science diagrams) to captioning UI components. The full list can be found in Table 1 of the paper. We therefore advise you to use these models for the tasks they have been fine-tuned on. For instance, if you want to use Pix2Struct for UI captioning, you should use the model fine-tuned on the UI dataset; if you want to use it for image captioning, you should use the model fine-tuned on the natural image captioning dataset, and so on.

If you want to use the model to perform conditional text captioning, make sure to use the processor with add_special_tokens=False.
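A minimal sketch of this setting is shown below (the checkpoint name and image path are placeholders; a complete example is given under Pix2StructForConditionalGeneration further down):

from PIL import Image
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

image = Image.open("image.png")  # placeholder image path

# The text is used as a decoder prefix; add_special_tokens=False keeps special tokens
# out of the prefix so generation continues the prompt naturally.
inputs = processor(images=image, text="A picture of", return_tensors="pt", add_special_tokens=False)
generated_ids = model.generate(**inputs, max_new_tokens=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])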

This model was contributed by ybelkada. The original code can be found here.

Resources

Pix2StructConfig

class transformers.Pix2StructConfig


( text_config = None vision_config = None initializer_factor = 1.0 initializer_range = 0.02 is_vqa = False tie_word_embeddings = False is_encoder_decoder = True **kwargs )

Parameters

Pix2StructConfig is the configuration class to store the configuration of a Pix2StructForConditionalGeneration. It is used to instantiate a Pix2Struct model according to the specified arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the Pix2Struct-base google/pix2struct-base architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

from transformers import Pix2StructConfig, Pix2StructForConditionalGeneration, Pix2StructTextConfig, Pix2StructVisionConfig

configuration = Pix2StructConfig()

model = Pix2StructForConditionalGeneration(configuration)

configuration = model.config

config_text = Pix2StructTextConfig()
config_vision = Pix2StructVisionConfig()

config = Pix2StructConfig.from_text_vision_configs(config_text, config_vision)

from_text_vision_configs


( text_config: Pix2StructTextConfig vision_config: Pix2StructVisionConfig **kwargs ) → Pix2StructConfig

Returns: an instance of a Pix2StructConfig configuration object.

Instantiate a Pix2StructConfig (or a derived class) from a Pix2Struct text model configuration and a Pix2Struct vision model configuration.

Pix2StructTextConfig

class transformers.Pix2StructTextConfig


( vocab_size = 50244 hidden_size = 768 d_kv = 64 d_ff = 2048 num_layers = 12 num_heads = 12 relative_attention_num_buckets = 32 relative_attention_max_distance = 128 dropout_rate = 0.1 layer_norm_epsilon = 1e-06 initializer_factor = 1.0 dense_act_fn = 'gelu_new' decoder_start_token_id = 0 use_cache = False pad_token_id = 0 eos_token_id = 1 tie_word_embeddings = False is_decoder = True **kwargs )

Parameters

This is the configuration class to store the configuration of a Pix2StructTextModel. It is used to instantiate a Pix2Struct text model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Pix2Struct text decoder used by the google/pix2struct-base architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

from transformers import Pix2StructTextConfig, Pix2StructTextModel

configuration = Pix2StructTextConfig()

model = Pix2StructTextModel(configuration)

configuration = model.config

Pix2StructVisionConfig

class transformers.Pix2StructVisionConfig


( hidden_size = 768 patch_embed_hidden_size = 768 d_ff = 2048 d_kv = 64 num_hidden_layers = 12 num_attention_heads = 12 dense_act_fn = 'gelu_new' layer_norm_eps = 1e-06 dropout_rate = 0.0 attention_dropout = 0.0 initializer_range = 1e-10 initializer_factor = 1.0 seq_len = 4096 relative_attention_num_buckets = 32 relative_attention_max_distance = 128 **kwargs )

Parameters

This is the configuration class to store the configuration of a Pix2StructVisionModel. It is used to instantiate a Pix2Struct vision model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Pix2Struct-base google/pix2struct-base architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

from transformers import Pix2StructVisionConfig, Pix2StructVisionModel

configuration = Pix2StructVisionConfig()

model = Pix2StructVisionModel(configuration)

configuration = model.config

Pix2StructProcessor

class transformers.Pix2StructProcessor


( image_processor tokenizer )

Parameters

Constructs a Pix2Struct processor which wraps a T5 tokenizer and a Pix2Struct image processor into a single processor.

Pix2StructProcessor offers all the functionalities of Pix2StructImageProcessor and T5TokenizerFast. See the docstring of __call__() and decode() for more information.

batch_decode

This method forwards all its arguments to T5TokenizerFast’s batch_decode(). Please refer to the docstring of this method for more information.

decode

This method forwards all its arguments to T5TokenizerFast’s decode(). Please refer to the docstring of this method for more information.
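As a usage sketch, the processor can prepare an image together with a question for a VQA fine-tuned checkpoint; google/pix2struct-docvqa-base is used below as an example checkpoint, and for such is_vqa checkpoints the question is rendered as a header on top of the image rather than tokenized as decoder input:

import requests
from PIL import Image
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

processor = AutoProcessor.from_pretrained("google/pix2struct-docvqa-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-base")

url = "https://www.ilankelman.org/stopsigns/australia.jpg"  # same example image as in the examples below
image = Image.open(requests.get(url, stream=True).raw)

# With a VQA checkpoint, the question is rendered on top of the image by the image processor
inputs = processor(images=image, text="What is written on the sign?", return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])

For non-VQA checkpoints, text passed to the processor is tokenized and returned as decoder_input_ids instead (see the conditional generation example under Pix2StructForConditionalGeneration).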

Pix2StructImageProcessor

class transformers.Pix2StructImageProcessor


( do_convert_rgb: bool = True do_normalize: bool = True patch_size: typing.Optional[dict[str, int]] = None max_patches: int = 2048 is_vqa: bool = False **kwargs )

Parameters

Constructs a Pix2Struct image processor.

preprocess


( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']] header_text: typing.Optional[str] = None do_convert_rgb: typing.Optional[bool] = None do_normalize: typing.Optional[bool] = None max_patches: typing.Optional[int] = None patch_size: typing.Optional[dict[str, int]] = None return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None data_format: ChannelDimension = <ChannelDimension.FIRST: 'channels_first'> input_data_format: typing.Union[str, transformers.image_utils.ChannelDimension, NoneType] = None **kwargs )

Parameters

Preprocess an image or batch of images. The processor first computes the maximum number of aspect-ratio-preserving patches of size patch_size that can be extracted from the image. It then pads the image with zeros so that it respects the max_patches constraint. Before extracting the patches, the images are standardized following the TensorFlow implementation of per_image_standardization (https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization).
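As a rough sketch of calling the image processor directly (assuming a single RGB PIL image; with the default 16x16 patch_size, each flattened patch row has length 2 + 16 * 16 * 3 = 770, where the first two values encode the patch's row and column position):

from PIL import Image
from transformers import Pix2StructImageProcessor

image_processor = Pix2StructImageProcessor(max_patches=1024)  # lower than the default 2048, for illustration

image = Image.new("RGB", (640, 480), color="white")  # placeholder image
encoding = image_processor(images=image, return_tensors="pt")

# flattened_patches: (batch_size, max_patches, 2 + patch_height * patch_width * num_channels)
# attention_mask marks which of the max_patches slots hold real patches rather than padding
print(encoding.flattened_patches.shape)  # torch.Size([1, 1024, 770])
print(encoding.attention_mask.shape)  # torch.Size([1, 1024])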

Pix2StructTextModel

class transformers.Pix2StructTextModel


( config )

Parameters

The standalone text decoder of Pix2Struct.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.FloatTensor] = None encoder_hidden_states: typing.Optional[torch.FloatTensor] = None encoder_attention_mask: typing.Optional[torch.FloatTensor] = None inputs_embeds: typing.Optional[torch.LongTensor] = None head_mask: typing.Optional[torch.FloatTensor] = None cross_attn_head_mask: typing.Optional[torch.Tensor] = None past_key_values: typing.Optional[tuple[tuple[torch.FloatTensor]]] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None labels: typing.Optional[torch.LongTensor] = None return_dict: typing.Optional[bool] = None cache_position: typing.Optional[torch.LongTensor] = None **kwargs ) → transformers.modeling_outputs.CausalLMOutputWithCrossAttentions or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.CausalLMOutputWithCrossAttentions or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Pix2StructConfig) and inputs.

The Pix2StructTextModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoProcessor, Pix2StructTextModel

processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
model = Pix2StructTextModel.from_pretrained("google/pix2struct-textcaps-base")

inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)
loss = outputs.loss

Pix2StructVisionModel

class transformers.Pix2StructVisionModel


( config: Pix2StructConfig )

Parameters

The bare Pix2Struct Model outputting raw hidden-states without any specific head on top.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( flattened_patches: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.BaseModelOutputWithPooling or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Pix2StructConfig) and inputs.

The Pix2StructVisionModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Pix2StructVisionModel

image_processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state
list(last_hidden_states.shape)  # [1, 2048, 768]

Pix2StructForConditionalGeneration

class transformers.Pix2StructForConditionalGeneration


( config: Pix2StructConfig )

Parameters

A conditional generation model with a language modeling head. Can be used for sequence generation tasks.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( flattened_patches: typing.Optional[torch.FloatTensor] = None attention_mask: typing.Optional[torch.FloatTensor] = None decoder_input_ids: typing.Optional[torch.LongTensor] = None decoder_attention_mask: typing.Optional[torch.BoolTensor] = None head_mask: typing.Optional[torch.FloatTensor] = None decoder_head_mask: typing.Optional[torch.FloatTensor] = None cross_attn_head_mask: typing.Optional[torch.Tensor] = None encoder_outputs: typing.Optional[tuple[tuple[torch.FloatTensor]]] = None past_key_values: typing.Optional[tuple[tuple[torch.FloatTensor]]] = None labels: typing.Optional[torch.LongTensor] = None decoder_inputs_embeds: typing.Optional[torch.Tensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None cache_position: typing.Optional[torch.LongTensor] = None ) → transformers.modeling_outputs.Seq2SeqModelOutput or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.Seq2SeqModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Pix2StructConfig) and inputs.

The Pix2StructForConditionalGeneration forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

Inference:

from PIL import Image
import requests
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Autoregressive generation: unconditional image captioning
inputs = processor(images=image, return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=50)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)
# A stop sign is on a street corner.

# Conditional generation: the text is used as a decoder prefix
text = "A picture of"
inputs = processor(text=text, images=image, return_tensors="pt", add_special_tokens=False)
generated_ids = model.generate(**inputs, max_new_tokens=50)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)
# A picture of a stop sign with a red stop sign

Training:

from PIL import Image
import requests
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

processor = AutoProcessor.from_pretrained("google/pix2struct-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base")

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = "A stop sign is on the street corner."

inputs = processor(images=image, return_tensors="pt")
labels = processor(text=text, return_tensors="pt").input_ids

outputs = model(**inputs, labels=labels)
loss = outputs.loss
print(f"{loss.item():.5f}")
# 5.94282
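Continuing from the training snippet above, a minimal sketch of a single optimization step (the optimizer choice and learning rate are illustrative assumptions, not values from the paper):

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # assumed hyperparameters

model.train()
outputs = model(**inputs, labels=labels)  # inputs and labels from the training example above
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()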
