UDOP

PyTorch

Overview

The UDOP model was proposed in Unifying Vision, Text, and Layout for Universal Document Processing by Zineng Tang, Ziyi Yang, Guoxin Wang, Yuwei Fang, Yang Liu, Chenguang Zhu, Michael Zeng, Cha Zhang, Mohit Bansal. UDOP adopts an encoder-decoder Transformer architecture based on T5 for document AI tasks like document image classification, document parsing and document visual question answering.

The abstract from the paper is the following:

We propose Universal Document Processing (UDOP), a foundation Document AI model which unifies text, image, and layout modalities together with varied task formats, including document understanding and generation. UDOP leverages the spatial correlation between textual content and document image to model image, text, and layout modalities with one uniform representation. With a novel Vision-Text-Layout Transformer, UDOP unifies pretraining and multi-domain downstream tasks into a prompt-based sequence generation scheme. UDOP is pretrained on both large-scale unlabeled document corpora using innovative self-supervised objectives and diverse labeled data. UDOP also learns to generate document images from text and layout modalities via masked image reconstruction. To the best of our knowledge, this is the first time in the field of document AI that one model simultaneously achieves high-quality neural document editing and content customization. Our method sets the state-of-the-art on 9 Document AI tasks, e.g., document understanding and QA, across diverse data domains like finance reports, academic papers, and websites. UDOP ranks first on the leaderboard of the Document Understanding Benchmark (DUE).

UDOP architecture. Taken from the original paper.

Usage tips

In addition to input_ids, UdopModel.forward() also expects a bbox input, which contains the bounding boxes (i.e. 2D positions) of the input tokens. These can be obtained using an external OCR engine such as Google's Tesseract. Each bounding box should be in (x0, y0, x1, y1) format, where (x0, y0) corresponds to the upper left corner and (x1, y1) to the lower right corner, normalized to a 0-1000 scale. To normalize, you can use the following function:

```python
def normalize_bbox(bbox, width, height):
    return [
        int(1000 * (bbox[0] / width)),
        int(1000 * (bbox[1] / height)),
        int(1000 * (bbox[2] / width)),
        int(1000 * (bbox[3] / height)),
    ]
```

Here, width and height correspond to the width and height of the original document in which the token occurs. Those can be obtained using the Python Imaging Library (PIL), for example as follows:

```python
from PIL import Image

image = Image.open(name_of_your_document).convert("RGB")

width, height = image.size
```
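For example, a box returned by your OCR engine in pixel coordinates can then be normalized like so (word_box below is a hypothetical value for illustration):

```python
# Hypothetical OCR output: an (x0, y0, x1, y1) box in pixel coordinates
word_box = [100, 50, 220, 80]
bbox = normalize_bbox(word_box, width, height)
```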

One can use UdopProcessor to prepare images and text for the model, which takes care of all of this. By default, this class uses the Tesseract engine to extract a list of words and boxes (coordinates) from a given document. Its functionality is equivalent to that of LayoutLMv3Processor: pass apply_ocr=False if you prefer to use your own OCR engine, or apply_ocr=True if you want the default OCR engine to be used. Refer to the usage guide of LayoutLMv2 for all possible use cases (the functionality of UdopProcessor is identical).
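As a minimal sketch of the default OCR path (this assumes Tesseract and its Python wrapper pytesseract are installed, and "document.png" is a placeholder file name):

```python
from transformers import UdopProcessor
from PIL import Image

processor = UdopProcessor.from_pretrained("microsoft/udop-large")
image = Image.open("document.png").convert("RGB")  # hypothetical document image

# With apply_ocr=True (the default), words and boxes are extracted by Tesseract
encoding = processor(images=image, return_tensors="pt")
print(encoding.keys())  # e.g. input_ids, attention_mask, bbox, pixel_values
```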

This model was contributed by nielsr. The original code can be found here.

Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with UDOP. If you’re interested in submitting a resource to be included here, please feel free to open a Pull Request and we’ll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

UdopConfig

class transformers.UdopConfig


( vocab_size = 33201 d_model = 1024 d_kv = 64 d_ff = 4096 num_layers = 24 num_decoder_layers = None num_heads = 16 relative_attention_num_buckets = 32 relative_attention_max_distance = 128 relative_bias_args = [{'type': '1d'}, {'type': 'horizontal'}, {'type': 'vertical'}] dropout_rate = 0.1 layer_norm_epsilon = 1e-06 initializer_factor = 1.0 feed_forward_proj = 'relu' is_encoder_decoder = True use_cache = True pad_token_id = 0 eos_token_id = 1 max_2d_position_embeddings = 1024 image_size = 224 patch_size = 16 num_channels = 3 **kwargs )

This is the configuration class to store the configuration of a UdopForConditionalGeneration. It is used to instantiate a UDOP model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of the UDOP microsoft/udop-large architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
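For instance, a minimal sketch of instantiating a randomly initialized model from the default configuration:

```python
from transformers import UdopConfig, UdopForConditionalGeneration

# Default configuration, similar to that of microsoft/udop-large
configuration = UdopConfig()

# Randomly initialized model with that configuration
model = UdopForConditionalGeneration(configuration)

# The configuration can be accessed back from the model
configuration = model.config
```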

UdopTokenizer

class transformers.UdopTokenizer


( vocab_file eos_token = '</s>' unk_token = '<unk>' sep_token = '</s>' pad_token = '<pad>' sep_token_box = [1000, 1000, 1000, 1000] pad_token_box = [0, 0, 0, 0] pad_token_label = -100 only_label_first_subword = True additional_special_tokens = None sp_model_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None legacy = True add_prefix_space = True **kwargs )

Adapted from LayoutXLMTokenizer and T5Tokenizer. Based on SentencePiece.

This tokenizer inherits from PreTrainedTokenizer which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.
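A minimal sketch of encoding pre-extracted words together with their normalized boxes (the words and boxes below are made-up OCR output for illustration):

```python
from transformers import UdopTokenizer

tokenizer = UdopTokenizer.from_pretrained("microsoft/udop-large")
words = ["hello", "world"]                      # hypothetical OCR words
boxes = [[48, 84, 73, 128], [74, 84, 98, 128]]  # normalized to the 0-1000 scale
encoding = tokenizer(words, boxes=boxes, return_tensors="pt")
print(encoding.input_ids.shape)
```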

build_inputs_with_special_tokens


( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None ) → List[int]

Returns

List[int]

List of input IDs with the appropriate special tokens.

Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and adding special tokens. A sequence has the following format:

- single sequence: `X </s>`
- pair of sequences: `A </s> B </s>`

get_special_tokens_mask


( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None already_has_special_tokens: bool = False ) → List[int]

Returns

List[int]

A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.

Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer prepare_for_model method.

create_token_type_ids_from_sequences


( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None ) → List[int]

Returns

List[int]

List of zeros.

Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make use of token type ids, therefore a list of zeros is returned.
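A small sketch, reusing the tokenizer from the sketch above:

```python
ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("hello world"))
token_type_ids = tokenizer.create_token_type_ids_from_sequences(ids)
assert set(token_type_ids) == {0}  # T5-style models use no token type ids
```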

save_vocabulary


( save_directory: str filename_prefix: typing.Optional[str] = None )

UdopTokenizerFast

class transformers.UdopTokenizerFast


( vocab_file = None tokenizer_file = None eos_token = '</s>' sep_token = '</s>' unk_token = '<unk>' pad_token = '<pad>' sep_token_box = [1000, 1000, 1000, 1000] pad_token_box = [0, 0, 0, 0] pad_token_label = -100 only_label_first_subword = True additional_special_tokens = None **kwargs )

Construct a "fast" UDOP tokenizer (backed by HuggingFace's tokenizers library). Adapted from LayoutXLMTokenizer and T5Tokenizer. Based on BPE.

This tokenizer inherits from PreTrainedTokenizerFast which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

batch_encode_plus_boxes


( batch_text_or_text_pairs: typing.Union[typing.List[str], typing.List[typing.Tuple[str, str]], typing.List[typing.List[str]]] is_pair: typing.Optional[bool] = None boxes: typing.Optional[typing.List[typing.List[typing.List[int]]]] = None word_labels: typing.Optional[typing.List[typing.List[int]]] = None add_special_tokens: bool = True padding: typing.Union[bool, str, transformers.utils.generic.PaddingStrategy] = False truncation: typing.Union[bool, str, transformers.tokenization_utils_base.TruncationStrategy] = None max_length: typing.Optional[int] = None stride: int = 0 is_split_into_words: bool = False pad_to_multiple_of: typing.Optional[int] = None padding_side: typing.Optional[str] = None return_tensors: typing.Union[transformers.utils.generic.TensorType, str, NoneType] = None return_token_type_ids: typing.Optional[bool] = None return_attention_mask: typing.Optional[bool] = None return_overflowing_tokens: bool = False return_special_tokens_mask: bool = False return_offsets_mapping: bool = False return_length: bool = False verbose: bool = True **kwargs )


Tokenize and prepare for the model a list of sequences or a list of pairs of sequences.

This method is deprecated; __call__ should be used instead.

build_inputs_with_special_tokens


( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None ) → List[int]

Returns

List[int]

List of input IDs with the appropriate special tokens.

Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and adding special tokens.

call_boxes


( text: typing.Union[str, typing.List[str], typing.List[typing.List[str]]] text_pair: typing.Union[typing.List[str], typing.List[typing.List[str]], NoneType] = None boxes: typing.Union[typing.List[typing.List[int]], typing.List[typing.List[typing.List[int]]], NoneType] = None word_labels: typing.Union[typing.List[int], typing.List[typing.List[int]], NoneType] = None add_special_tokens: bool = True padding: typing.Union[bool, str, transformers.utils.generic.PaddingStrategy] = False truncation: typing.Union[bool, str, transformers.tokenization_utils_base.TruncationStrategy] = None max_length: typing.Optional[int] = None stride: int = 0 pad_to_multiple_of: typing.Optional[int] = None padding_side: typing.Optional[str] = None return_tensors: typing.Union[transformers.utils.generic.TensorType, str, NoneType] = None return_token_type_ids: typing.Optional[bool] = None return_attention_mask: typing.Optional[bool] = None return_overflowing_tokens: bool = False return_special_tokens_mask: bool = False return_offsets_mapping: bool = False return_length: bool = False verbose: bool = True **kwargs ) → BatchEncoding

Returns

BatchEncoding

A BatchEncoding with the encoded inputs; the exact fields depend on the arguments passed.

Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of sequences with word-level normalized bounding boxes and optional labels.
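A minimal sketch, using made-up words and boxes (calling the tokenizer directly with boxes should be equivalent, since __call__ dispatches here when boxes are provided):

```python
from transformers import UdopTokenizerFast

tokenizer = UdopTokenizerFast.from_pretrained("microsoft/udop-large")
words = ["invoice", "total"]                  # hypothetical OCR words
boxes = [[10, 10, 90, 30], [10, 40, 60, 60]]  # normalized to the 0-1000 scale
encoding = tokenizer.call_boxes(words, boxes=boxes, return_tensors="pt")
print(sorted(encoding.keys()))  # e.g. ['attention_mask', 'bbox', 'input_ids']
```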

create_token_type_ids_from_sequences


( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None ) → List[int]

Returns

List[int]

List of zeros.

Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does not make use of token type ids, therefore a list of zeros is returned.

encode_boxes


( text: typing.Union[str, typing.List[str], typing.List[int]] text_pair: typing.Union[str, typing.List[str], typing.List[int], NoneType] = None boxes: typing.Optional[typing.List[typing.List[int]]] = None word_labels: typing.Optional[typing.List[typing.List[int]]] = None add_special_tokens: bool = True padding: typing.Union[bool, str, transformers.utils.generic.PaddingStrategy] = False truncation: typing.Union[bool, str, transformers.tokenization_utils_base.TruncationStrategy] = None max_length: typing.Optional[int] = None stride: int = 0 return_tensors: typing.Union[transformers.utils.generic.TensorType, str, NoneType] = None **kwargs )

Tokenize a sequence or a pair of sequences with word-level normalized bounding boxes and convert it to a list of input ids, analogous to encode.

encode_plus_boxes


( text: typing.Union[str, typing.List[str]] text_pair: typing.Optional[typing.List[str]] = None boxes: typing.Optional[typing.List[typing.List[int]]] = None word_labels: typing.Optional[typing.List[typing.List[int]]] = None add_special_tokens: bool = True padding: typing.Union[bool, str, transformers.utils.generic.PaddingStrategy] = False truncation: typing.Union[bool, str, transformers.tokenization_utils_base.TruncationStrategy] = None max_length: typing.Optional[int] = None stride: int = 0 is_split_into_words: bool = False pad_to_multiple_of: typing.Optional[int] = None padding_side: typing.Optional[str] = None return_tensors: typing.Union[transformers.utils.generic.TensorType, str, NoneType] = None return_token_type_ids: typing.Optional[bool] = None return_attention_mask: typing.Optional[bool] = None return_overflowing_tokens: bool = False return_special_tokens_mask: bool = False return_offsets_mapping: bool = False return_length: bool = False verbose: bool = True **kwargs )


Tokenize and prepare for the model a sequence or a pair of sequences.

This method is deprecated; __call__ should be used instead.

UdopProcessor

class transformers.UdopProcessor


( image_processor tokenizer )


Constructs a UDOP processor which combines a LayoutLMv3 image processor and a UDOP tokenizer into a single processor.

UdopProcessor offers all the functionalities you need to prepare data for the model.

It first uses LayoutLMv3ImageProcessor to resize, rescale and normalize document images, and optionally applies OCR to get words and normalized bounding boxes. These are then provided to UdopTokenizer or UdopTokenizerFast, which turns the words and bounding boxes into token-level input_ids, attention_mask, token_type_ids and bbox. Optionally, one can provide integer word_labels, which are turned into token-level labels for token classification tasks (such as FUNSD, CORD).

Additionally, it also supports passing text_target and text_pair_target to the tokenizer, which can be used to prepare labels for language modeling tasks.
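For instance, a hedged sketch of preparing generation labels in the same call (this assumes processor, image, words and boxes as in the examples further below; the target string is a made-up value):

```python
encoding = processor(
    image,
    words,
    boxes=boxes,
    text_target="9/30/92",  # hypothetical target string
    return_tensors="pt",
)
print("labels" in encoding)  # the target text is tokenized into decoder labels
```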

__call__


( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor'], NoneType] = None text: typing.Union[str, typing.List[str], typing.List[typing.List[str]]] = None *args audio = None videos = None **kwargs: typing_extensions.Unpack[transformers.models.udop.processing_udop.UdopProcessorKwargs] )

This method first forwards the images argument to UdopImageProcessor.__call__(). In case UdopImageProcessor was initialized with apply_ocr set to True, it passes the obtained words and bounding boxes along with the additional arguments to __call__() and returns the output, together with the prepared pixel_values. In case UdopImageProcessor was initialized with apply_ocr set to False, it passes the words (text/text_pair) and boxes specified by the user along with the additional arguments to __call__() and returns the output, together with the prepared pixel_values.

Alternatively, one can pass text_target and text_pair_target to prepare the targets of UDOP.

Please refer to the docstring of the above two methods for more information.

UdopModel

class transformers.UdopModel


( config )

The bare UDOP encoder-decoder Transformer outputting raw hidden-states without any specific head on top. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.)

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_ids: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None bbox: typing.Optional[typing.Dict[str, typing.Any]] = None pixel_values: typing.Optional[torch.Tensor] = None visual_bbox: typing.Optional[typing.Dict[str, typing.Any]] = None decoder_input_ids: typing.Optional[torch.Tensor] = None decoder_attention_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None encoder_outputs: typing.Optional[torch.Tensor] = None past_key_values: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None decoder_inputs_embeds: typing.Optional[torch.Tensor] = None decoder_head_mask: typing.Optional[torch.Tensor] = None cross_attn_head_mask: typing.Optional[torch.Tensor] = None use_cache = True output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None cache_position: typing.Optional[torch.LongTensor] = None ) → transformers.modeling_outputs.Seq2SeqModelOutput or tuple(torch.FloatTensor)

Returns

transformers.modeling_outputs.Seq2SeqModelOutput or tuple(torch.FloatTensor)

A transformers.modeling_outputs.Seq2SeqModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (UdopConfig) and inputs.

The UdopModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

```python
from transformers import AutoProcessor, AutoModel
from datasets import load_dataset
import torch

processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
model = AutoModel.from_pretrained("microsoft/udop-large")

dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
example = dataset[0]
image = example["image"]
words = example["tokens"]
boxes = example["bboxes"]
inputs = processor(image, words, boxes=boxes, return_tensors="pt")

decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
last_hidden_states = outputs.last_hidden_state
list(last_hidden_states.shape)  # [1, 1, 1024]
```

UdopForConditionalGeneration

class transformers.UdopForConditionalGeneration


( config )

The UDOP encoder-decoder Transformer with a language modeling head on top, making it possible to generate text given document images and an optional prompt.

This class is based on T5ForConditionalGeneration, extended to deal with images and layout (2D) data. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.)

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_ids: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None bbox: typing.Optional[typing.Dict[str, typing.Any]] = None pixel_values: typing.Optional[torch.Tensor] = None visual_bbox: typing.Optional[typing.Dict[str, typing.Any]] = None decoder_input_ids: typing.Optional[torch.Tensor] = None decoder_attention_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None encoder_outputs: typing.Optional[torch.Tensor] = None past_key_values: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None decoder_inputs_embeds: typing.Optional[torch.Tensor] = None decoder_head_mask: typing.Optional[torch.Tensor] = None cross_attn_head_mask: typing.Optional[torch.Tensor] = None use_cache = True output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None labels: typing.Optional[torch.Tensor] = None cache_position: typing.Optional[torch.LongTensor] = None ) → transformers.modeling_outputs.Seq2SeqLMOutput or tuple(torch.FloatTensor)

Returns

transformers.modeling_outputs.Seq2SeqLMOutput or tuple(torch.FloatTensor)

A transformers.modeling_outputs.Seq2SeqLMOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (UdopConfig) and inputs.

The UdopForConditionalGeneration forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
from transformers import AutoProcessor, UdopForConditionalGeneration
from datasets import load_dataset

processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
model = UdopForConditionalGeneration.from_pretrained("microsoft/udop-large")

dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
example = dataset[0]
image = example["image"]
words = example["tokens"]
boxes = example["bboxes"]

question = "Question answering. What is the date on the form?"
encoding = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt")

predicted_ids = model.generate(**encoding)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
# 9/30/92
```

UdopEncoderModel

class transformers.UdopEncoderModel


( config: UdopConfig )

The bare UDOP Model transformer outputting the encoder's raw hidden-states without any specific head on top. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.)

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_ids: typing.Optional[torch.Tensor] = None bbox: typing.Optional[typing.Dict[str, typing.Any]] = None attention_mask: typing.Optional[torch.Tensor] = None pixel_values: typing.Optional[torch.Tensor] = None visual_bbox: typing.Optional[typing.Dict[str, typing.Any]] = None head_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.models.udop.modeling_udop.BaseModelOutputWithAttentionMask or tuple(torch.FloatTensor)


Returns

transformers.models.udop.modeling_udop.BaseModelOutputWithAttentionMask or tuple(torch.FloatTensor)

A transformers.models.udop.modeling_udop.BaseModelOutputWithAttentionMask or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (UdopConfig) and inputs.

The UdopEncoderModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

```python
from transformers import AutoProcessor, UdopEncoderModel
from datasets import load_dataset

processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
model = UdopEncoderModel.from_pretrained("microsoft/udop-large")

dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
example = dataset[0]
image = example["image"]
words = example["tokens"]
boxes = example["bboxes"]
encoding = processor(image, words, boxes=boxes, return_tensors="pt")

outputs = model(**encoding)
last_hidden_states = outputs.last_hidden_state
```
