LayoutLMV2

PyTorch

Overview

The LayoutLMV2 model was proposed in "LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding" by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. LayoutLMV2 improves on LayoutLM to obtain state-of-the-art results across several document image understanding benchmarks.

The abstract from the paper is the following:

Pre-training of text and layout has proved effective in a variety of visually-rich document understanding tasks due to its effective model architecture and the advantage of large-scale unlabeled scanned/digital-born documents. In this paper, we present LayoutLMv2 by pre-training text, layout and image in a multi-modal framework, where new model architectures and pre-training tasks are leveraged. Specifically, LayoutLMv2 not only uses the existing masked visual-language modeling task but also the new text-image alignment and text-image matching tasks in the pre-training stage, where cross-modality interaction is better learned. Meanwhile, it also integrates a spatial-aware self-attention mechanism into the Transformer architecture, so that the model can fully understand the relative positional relationship among different text blocks. Experiment results show that LayoutLMv2 outperforms strong baselines and achieves new state-of-the-art results on a wide variety of downstream visually-rich document understanding tasks, including FUNSD (0.7895 -> 0.8420), CORD (0.9493 -> 0.9601), SROIE (0.9524 -> 0.9781), Kleister-NDA (0.834 -> 0.852), RVL-CDIP (0.9443 -> 0.9564), and DocVQA (0.7295 -> 0.8672). The pre-trained LayoutLMv2 model is publicly available at this https URL.

LayoutLMv2 depends on detectron2, torchvision and tesseract. Run the following to install them:

python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
python -m pip install torchvision tesseract

(If you are developing for LayoutLMv2, note that passing the doctests also requires the installation of these packages.)

Usage tips

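Bounding boxes passed to the model are expected on a 0-1000 scale. To normalize a box given in pixel coordinates, you can use the following function:

def normalize_bbox(bbox, width, height):
    return [
        int(1000 * (bbox[0] / width)),
        int(1000 * (bbox[1] / height)),
        int(1000 * (bbox[2] / width)),
        int(1000 * (bbox[3] / height)),
    ]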

Here, width and height correspond to the width and height of the original document in which the token occurs (before resizing the image). They can be obtained with the Python Imaging Library (PIL), for example, as follows:

from PIL import Image

image = Image.open( "name_of_your_document - can be a png, jpg, etc. of your documents (PDFs must be converted to images)." )

width, height = image.size
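As a minimal sketch of how the two snippets above fit together, the following applies normalize_bbox to word boxes given in pixel coordinates. The page size and the pixel boxes are made-up values chosen for illustration; in practice they would come from image.size and from your OCR engine.

```python
def normalize_bbox(bbox, width, height):
    """Scale a pixel-space (x0, y0, x1, y1) box to the 0-1000 range."""
    return [
        int(1000 * (bbox[0] / width)),
        int(1000 * (bbox[1] / height)),
        int(1000 * (bbox[2] / width)),
        int(1000 * (bbox[3] / height)),
    ]


# Hypothetical page size and OCR boxes in pixels, as (x0, y0, x1, y1).
width, height = 800, 1600
pixel_boxes = [[200, 400, 600, 800], [0, 0, 800, 1600]]

normalized_boxes = [normalize_bbox(box, width, height) for box in pixel_boxes]
print(normalized_boxes)  # [[250, 250, 750, 500], [0, 0, 1000, 1000]]
```

A box that covers the whole page maps to [0, 0, 1000, 1000], which is the largest box the 0-1000 scale can express.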

However, this model includes a brand new LayoutLMv2Processor which can be used to directly prepare data for the model (including applying OCR under the hood). More information can be found in the “Usage” section below.

In addition, there’s LayoutXLM, which is a multilingual version of LayoutLMv2. More information can be found on LayoutXLM’s documentation page.

Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with LayoutLMv2. If you’re interested in submitting a resource to be included here, please feel free to open a Pull Request and we’ll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

Usage: LayoutLMv2Processor

The easiest way to prepare data for the model is to use LayoutLMv2Processor, which internally combines an image processor (LayoutLMv2ImageProcessor) and a tokenizer (LayoutLMv2Tokenizer or LayoutLMv2TokenizerFast). The image processor handles the image modality, while the tokenizer handles the text modality. A processor combines both, which is ideal for a multi-modal model like LayoutLMv2. Note that you can still use both separately, if you only want to handle one modality.

from transformers import LayoutLMv2ImageProcessor, LayoutLMv2TokenizerFast, LayoutLMv2Processor

image_processor = LayoutLMv2ImageProcessor()
tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")
processor = LayoutLMv2Processor(image_processor, tokenizer)

In short, one can provide a document image (and possibly additional data) to LayoutLMv2Processor, and it will create the inputs expected by the model. Internally, the processor first uses LayoutLMv2ImageProcessor to apply OCR on the image to get a list of words and normalized bounding boxes, as well as to resize the image to a given size in order to get the image input. The words and normalized bounding boxes are then provided to LayoutLMv2Tokenizer or LayoutLMv2TokenizerFast, which converts them to token-level input_ids, attention_mask, token_type_ids, bbox. Optionally, one can provide word labels to the processor, which are turned into token-level labels.

LayoutLMv2Processor uses PyTesseract, a Python wrapper around Google’s Tesseract OCR engine, under the hood. Note that you can still use your own OCR engine of choice, and provide the words and normalized boxes yourself. This requires initializing LayoutLMv2ImageProcessor with apply_ocr set to False.

In total, the processor supports 5 use cases, all listed below. Each of these use cases works for both batched and non-batched inputs (we illustrate them for non-batched inputs).

Use case 1: document image classification (training, inference) + token classification (inference), apply_ocr = True

This is the simplest case, in which the processor (actually the image processor) will perform OCR on the image to get the words and normalized bounding boxes.

from transformers import LayoutLMv2Processor
from PIL import Image

processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")

image = Image.open(
    "name_of_your_document - can be a png, jpg, etc. of your documents (PDFs must be converted to images)."
).convert("RGB")
encoding = processor(image, return_tensors="pt")
print(encoding.keys())

Use case 2: document image classification (training, inference) + token classification (inference), apply_ocr=False

In case one wants to do OCR themselves, one can initialize the image processor with apply_ocr set to False. In that case, one should provide the words and corresponding (normalized) bounding boxes themselves to the processor.

from transformers import LayoutLMv2Processor
from PIL import Image

processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")

image = Image.open(
    "name_of_your_document - can be a png, jpg, etc. of your documents (PDFs must be converted to images)."
).convert("RGB")
words = ["hello", "world"]
boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
encoding = processor(image, words, boxes=boxes, return_tensors="pt")
print(encoding.keys())
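When supplying your own OCR results, it can help to sanity-check them before calling the processor. The helper below is a hypothetical sketch (check_boxes is not part of transformers). It assumes each box is an [x0, y0, x1, y1] list already normalized to the 0-1000 scale.

```python
def check_boxes(words, boxes, max_coord=1000):
    """Raise ValueError if words and boxes are inconsistent.

    Hypothetical helper for illustration, not part of transformers.
    """
    if len(words) != len(boxes):
        raise ValueError(f"{len(words)} words but {len(boxes)} boxes")
    for word, box in zip(words, boxes):
        if len(box) != 4:
            raise ValueError(f"box for {word!r} must have 4 coordinates: {box}")
        x0, y0, x1, y1 = box
        if not all(0 <= c <= max_coord for c in box):
            raise ValueError(f"box for {word!r} is outside 0-{max_coord}: {box}")
        if x1 < x0 or y1 < y0:
            raise ValueError(f"box for {word!r} has inverted corners: {box}")


words = ["hello", "world"]
boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
check_boxes(words, boxes)  # valid input: returns without raising

# A coordinate above 1000 usually means the box was never normalized.
try:
    check_boxes(words, [[1, 2, 3, 4], [5, 6, 7, 1200]])
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)
```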

Use case 3: token classification (training), apply_ocr=False

For token classification tasks (such as FUNSD, CORD, SROIE, Kleister-NDA), one can also provide the corresponding word labels in order to train a model. The processor will then convert these into token-level labels. By default, it will only label the first wordpiece of a word, and label the remaining wordpieces with -100, which is the ignore_index of PyTorch’s CrossEntropyLoss. In case you want all wordpieces of a word to be labeled, you can initialize the tokenizer with only_label_first_subword set to False.

from transformers import LayoutLMv2Processor
from PIL import Image

processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")

image = Image.open(
    "name_of_your_document - can be a png, jpg, etc. of your documents (PDFs must be converted to images)."
).convert("RGB")
words = ["hello", "world"]
boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
word_labels = [1, 2]
encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
print(encoding.keys())
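The labelling rule described for this use case can be illustrated without the library. The sketch below is not the tokenizer's implementation: toy_wordpieces is a hypothetical stand-in for WordPiece that simply cuts words into 3-character pieces, and special tokens are left out for brevity.

```python
IGNORE_INDEX = -100  # ignore_index of PyTorch's CrossEntropyLoss


def toy_wordpieces(word):
    """Hypothetical stand-in for WordPiece: split a word into 3-character pieces."""
    pieces = [word[i : i + 3] for i in range(0, len(word), 3)]
    return [pieces[0]] + ["##" + piece for piece in pieces[1:]]


def align_labels(words, word_labels, only_label_first_subword=True):
    """Expand word-level labels to one label per wordpiece."""
    tokens, labels = [], []
    for word, label in zip(words, word_labels):
        pieces = toy_wordpieces(word)
        tokens.extend(pieces)
        if only_label_first_subword:
            # First piece keeps the label, the rest are ignored by the loss.
            labels.extend([label] + [IGNORE_INDEX] * (len(pieces) - 1))
        else:
            labels.extend([label] * len(pieces))
    return tokens, labels


tokens, labels = align_labels(["hello", "world"], [1, 2])
print(tokens)  # ['hel', '##lo', 'wor', '##ld']
print(labels)  # [1, -100, 2, -100]

_, all_labels = align_labels(["hello", "world"], [1, 2], only_label_first_subword=False)
print(all_labels)  # [1, 1, 2, 2]
```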

Use case 4: visual question answering (inference), apply_ocr=True

For visual question answering tasks (such as DocVQA), you can provide a question to the processor. By default, the processor will apply OCR on the image, and create [CLS] question tokens [SEP] word tokens [SEP].

from transformers import LayoutLMv2Processor
from PIL import Image

processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")

image = Image.open(
    "name_of_your_document - can be a png, jpg, etc. of your documents (PDFs must be converted to images)."
).convert("RGB")
question = "What's his name?"
encoding = processor(image, question, return_tensors="pt")
print(encoding.keys())
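To make the [CLS] question tokens [SEP] word tokens [SEP] layout concrete, here is a library-free sketch. The special-token boxes are the tokenizer defaults listed in its signature. Two details are assumptions rather than facts stated on this page: that question tokens receive the padding box (they have no position on the page), and that token_type_ids follow the BERT convention of 0 for the first segment and 1 for the second.

```python
CLS_BOX = [0, 0, 0, 0]              # cls_token_box default
SEP_BOX = [1000, 1000, 1000, 1000]  # sep_token_box default
PAD_BOX = [0, 0, 0, 0]              # pad_token_box default


def build_qa_inputs(question_tokens, word_tokens, word_boxes):
    """Lay out a question and OCR tokens as [CLS] question [SEP] words [SEP]."""
    tokens = ["[CLS]"] + question_tokens + ["[SEP]"] + word_tokens + ["[SEP]"]
    # Assumption: question tokens get the padding box.
    bbox = (
        [CLS_BOX]
        + [PAD_BOX] * len(question_tokens)
        + [SEP_BOX]
        + word_boxes
        + [SEP_BOX]
    )
    # Assumption: BERT-style segment ids (0 = question part, 1 = document part).
    token_type_ids = [0] * (len(question_tokens) + 2) + [1] * (len(word_tokens) + 1)
    return tokens, bbox, token_type_ids


tokens, bbox, token_type_ids = build_qa_inputs(
    ["who", "?"], ["hello", "world"], [[1, 2, 3, 4], [5, 6, 7, 8]]
)
for token, box, segment in zip(tokens, bbox, token_type_ids):
    print(f"{token:8} {segment} {box}")
```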

Use case 5: visual question answering (inference), apply_ocr=False

For visual question answering tasks (such as DocVQA), you can provide a question to the processor. If you want to perform OCR yourself, you can provide your own words and (normalized) bounding boxes to the processor.

from transformers import LayoutLMv2Processor
from PIL import Image

processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")

image = Image.open(
    "name_of_your_document - can be a png, jpg, etc. of your documents (PDFs must be converted to images)."
).convert("RGB")
question = "What's his name?"
words = ["hello", "world"]
boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")
print(encoding.keys())

LayoutLMv2Config

class transformers.LayoutLMv2Config

( vocab_size = 30522 hidden_size = 768 num_hidden_layers = 12 num_attention_heads = 12 intermediate_size = 3072 hidden_act = 'gelu' hidden_dropout_prob = 0.1 attention_probs_dropout_prob = 0.1 max_position_embeddings = 512 type_vocab_size = 2 initializer_range = 0.02 layer_norm_eps = 1e-12 pad_token_id = 0 max_2d_position_embeddings = 1024 max_rel_pos = 128 rel_pos_bins = 32 fast_qkv = True max_rel_2d_pos = 256 rel_2d_pos_bins = 64 convert_sync_batchnorm = True image_feature_pool_shape = [7, 7, 256] coordinate_size = 128 shape_size = 128 has_relative_attention_bias = True has_spatial_attention_bias = True has_visual_segment_embedding = False detectron2_config_args = None **kwargs )

Parameters

This is the configuration class to store the configuration of a LayoutLMv2Model. It is used to instantiate a LayoutLMv2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the LayoutLMv2 microsoft/layoutlmv2-base-uncased architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

from transformers import LayoutLMv2Config, LayoutLMv2Model

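# Initialize a configuration with the default arguments (similar to microsoft/layoutlmv2-base-uncased)
configuration = LayoutLMv2Config()

# Initialize a model (with random weights) from that configuration
model = LayoutLMv2Model(configuration)

# Access the model configuration
configuration = model.config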

LayoutLMv2FeatureExtractor

Preprocess an image or a batch of images.

LayoutLMv2ImageProcessor

class transformers.LayoutLMv2ImageProcessor

( do_resize: bool = True size: typing.Dict[str, int] = None resample: Resampling = <Resampling.BILINEAR: 2> apply_ocr: bool = True ocr_lang: typing.Optional[str] = None tesseract_config: typing.Optional[str] = '' **kwargs )

Parameters

Constructs a LayoutLMv2 image processor.

preprocess

( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']] do_resize: typing.Optional[bool] = None size: typing.Dict[str, int] = None resample: Resampling = None apply_ocr: typing.Optional[bool] = None ocr_lang: typing.Optional[str] = None tesseract_config: typing.Optional[str] = None return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None data_format: ChannelDimension = <ChannelDimension.FIRST: 'channels_first'> input_data_format: typing.Union[transformers.image_utils.ChannelDimension, str, NoneType] = None )

Parameters

Preprocess an image or batch of images.

LayoutLMv2Tokenizer

class transformers.LayoutLMv2Tokenizer

( vocab_file do_lower_case = True do_basic_tokenize = True never_split = None unk_token = '[UNK]' sep_token = '[SEP]' pad_token = '[PAD]' cls_token = '[CLS]' mask_token = '[MASK]' cls_token_box = [0, 0, 0, 0] sep_token_box = [1000, 1000, 1000, 1000] pad_token_box = [0, 0, 0, 0] pad_token_label = -100 only_label_first_subword = True tokenize_chinese_chars = True strip_accents = None model_max_length: int = 512 additional_special_tokens: typing.Optional[typing.List[str]] = None **kwargs )

Construct a LayoutLMv2 tokenizer. Based on WordPiece. LayoutLMv2Tokenizer can be used to turn words, word-level bounding boxes and optional word labels to token-level input_ids, attention_mask, token_type_ids, bbox, and optional labels (for token classification).

This tokenizer inherits from PreTrainedTokenizer which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

LayoutLMv2Tokenizer runs end-to-end tokenization: punctuation splitting and wordpiece. It also turns the word-level bounding boxes into token-level bounding boxes.
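The word-level to token-level box conversion can be pictured as follows. This is a sketch under assumptions, not the tokenizer's code: the wordpiece split is supplied by hand (a real split comes from the WordPiece vocabulary), while the [CLS] and [SEP] boxes are the defaults shown in the signature above.

```python
CLS_BOX = [0, 0, 0, 0]              # cls_token_box default
SEP_BOX = [1000, 1000, 1000, 1000]  # sep_token_box default


def to_token_level(pieces_per_word, word_boxes):
    """Give every wordpiece the bounding box of the word it came from."""
    tokens, token_boxes = ["[CLS]"], [CLS_BOX]
    for pieces, box in zip(pieces_per_word, word_boxes):
        tokens.extend(pieces)
        token_boxes.extend([box] * len(pieces))
    tokens.append("[SEP]")
    token_boxes.append(SEP_BOX)
    return tokens, token_boxes


# Hypothetical wordpiece split of the words "hello" and "world".
pieces_per_word = [["hell", "##o"], ["world"]]
word_boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]

tokens, token_boxes = to_token_level(pieces_per_word, word_boxes)
for token, box in zip(tokens, token_boxes):
    print(token, box)
```

Both pieces of "hello" end up with the box [1, 2, 3, 4], so the bbox input always has the same length as input_ids.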

__call__

( text: typing.Union[str, typing.List[str], typing.List[typing.List[str]]] text_pair: typing.Union[typing.List[str], typing.List[typing.List[str]], NoneType] = None boxes: typing.Union[typing.List[typing.List[int]], typing.List[typing.List[typing.List[int]]]] = None word_labels: typing.Union[typing.List[int], typing.List[typing.List[int]], NoneType] = None add_special_tokens: bool = True padding: typing.Union[bool, str, transformers.utils.generic.PaddingStrategy] = False truncation: typing.Union[bool, str, transformers.tokenization_utils_base.TruncationStrategy] = None max_length: typing.Optional[int] = None stride: int = 0 pad_to_multiple_of: typing.Optional[int] = None padding_side: typing.Optional[str] = None return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None return_token_type_ids: typing.Optional[bool] = None return_attention_mask: typing.Optional[bool] = None return_overflowing_tokens: bool = False return_special_tokens_mask: bool = False return_offsets_mapping: bool = False return_length: bool = False verbose: bool = True **kwargs ) → BatchEncoding

Parameters

A BatchEncoding with the following fields:

Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of sequences with word-level normalized bounding boxes and optional labels.

save_vocabulary

( save_directory: str filename_prefix: typing.Optional[str] = None )

LayoutLMv2TokenizerFast

class transformers.LayoutLMv2TokenizerFast

( vocab_file = None tokenizer_file = None do_lower_case = True unk_token = '[UNK]' sep_token = '[SEP]' pad_token = '[PAD]' cls_token = '[CLS]' mask_token = '[MASK]' cls_token_box = [0, 0, 0, 0] sep_token_box = [1000, 1000, 1000, 1000] pad_token_box = [0, 0, 0, 0] pad_token_label = -100 only_label_first_subword = True tokenize_chinese_chars = True strip_accents = None **kwargs )

Parameters

Construct a “fast” LayoutLMv2 tokenizer (backed by HuggingFace’s tokenizers library). Based on WordPiece.

This tokenizer inherits from PreTrainedTokenizerFast which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

__call__

( text: typing.Union[str, typing.List[str], typing.List[typing.List[str]]] text_pair: typing.Union[typing.List[str], typing.List[typing.List[str]], NoneType] = None boxes: typing.Union[typing.List[typing.List[int]], typing.List[typing.List[typing.List[int]]]] = None word_labels: typing.Union[typing.List[int], typing.List[typing.List[int]], NoneType] = None add_special_tokens: bool = True padding: typing.Union[bool, str, transformers.utils.generic.PaddingStrategy] = False truncation: typing.Union[bool, str, transformers.tokenization_utils_base.TruncationStrategy] = None max_length: typing.Optional[int] = None stride: int = 0 pad_to_multiple_of: typing.Optional[int] = None padding_side: typing.Optional[str] = None return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None return_token_type_ids: typing.Optional[bool] = None return_attention_mask: typing.Optional[bool] = None return_overflowing_tokens: bool = False return_special_tokens_mask: bool = False return_offsets_mapping: bool = False return_length: bool = False verbose: bool = True **kwargs ) → BatchEncoding

Parameters

A BatchEncoding with the following fields:

Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of sequences with word-level normalized bounding boxes and optional labels.

LayoutLMv2Processor

class transformers.LayoutLMv2Processor

( image_processor = None tokenizer = None **kwargs )

Parameters

Constructs a LayoutLMv2 processor which combines a LayoutLMv2 image processor and a LayoutLMv2 tokenizer into a single processor.

LayoutLMv2Processor offers all the functionalities you need to prepare data for the model.

It first uses LayoutLMv2ImageProcessor to resize document images to a fixed size, and optionally applies OCR to get words and normalized bounding boxes. These are then provided to LayoutLMv2Tokenizer or LayoutLMv2TokenizerFast, which turns the words and bounding boxes into token-level input_ids, attention_mask, token_type_ids, bbox. Optionally, one can provide integer word_labels, which are turned into token-level labels for token classification tasks (such as FUNSD, CORD).

__call__

( images text: typing.Union[str, typing.List[str], typing.List[typing.List[str]]] = None text_pair: typing.Union[typing.List[str], typing.List[typing.List[str]], NoneType] = None boxes: typing.Union[typing.List[typing.List[int]], typing.List[typing.List[typing.List[int]]]] = None word_labels: typing.Union[typing.List[int], typing.List[typing.List[int]], NoneType] = None add_special_tokens: bool = True padding: typing.Union[bool, str, transformers.utils.generic.PaddingStrategy] = False truncation: typing.Union[bool, str, transformers.tokenization_utils_base.TruncationStrategy] = False max_length: typing.Optional[int] = None stride: int = 0 pad_to_multiple_of: typing.Optional[int] = None return_token_type_ids: typing.Optional[bool] = None return_attention_mask: typing.Optional[bool] = None return_overflowing_tokens: bool = False return_special_tokens_mask: bool = False return_offsets_mapping: bool = False return_length: bool = False verbose: bool = True return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None **kwargs )

This method first forwards the images argument to the __call__() method of LayoutLMv2ImageProcessor. In case LayoutLMv2ImageProcessor was initialized with apply_ocr set to True, it passes the obtained words and bounding boxes along with the additional arguments to the __call__() method of LayoutLMv2Tokenizer and returns the output, together with the resized images. In case LayoutLMv2ImageProcessor was initialized with apply_ocr set to False, it passes the words (text/text_pair) and boxes specified by the user along with the additional arguments to the __call__() method of LayoutLMv2Tokenizer and returns the output, together with the resized images.

Please refer to the docstring of the above two methods for more information.

LayoutLMv2Model

class transformers.LayoutLMv2Model

( config )

Parameters

The bare LayoutLMv2 Model transformer outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.LongTensor] = None bbox: typing.Optional[torch.LongTensor] = None image: typing.Optional[torch.FloatTensor] = None attention_mask: typing.Optional[torch.FloatTensor] = None token_type_ids: typing.Optional[torch.LongTensor] = None position_ids: typing.Optional[torch.LongTensor] = None head_mask: typing.Optional[torch.FloatTensor] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.BaseModelOutput or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.BaseModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (LayoutLMv2Config) and inputs.

The LayoutLMv2Model forward method, overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Examples:

from transformers import AutoProcessor, LayoutLMv2Model, set_seed
from PIL import Image
import torch
from datasets import load_dataset

set_seed(0)

processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
model = LayoutLMv2Model.from_pretrained("microsoft/layoutlmv2-base-uncased")

dataset = load_dataset("hf-internal-testing/fixtures_docvqa", trust_remote_code=True)
image_path = dataset["test"][0]["file"]
image = Image.open(image_path).convert("RGB")

encoding = processor(image, return_tensors="pt")

outputs = model(**encoding)
last_hidden_states = outputs.last_hidden_state

last_hidden_states.shape
torch.Size([1, 342, 768])
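The sequence dimension of 342 in the output above is larger than the number of text tokens because the model also returns hidden states for visual tokens. As a rough sketch, assuming one visual token per cell of the pooled feature map (image_feature_pool_shape in LayoutLMv2Config, 7 x 7 by default), the shape can be worked out like this. The helper name is hypothetical.

```python
# Defaults taken from the LayoutLMv2Config signature.
image_feature_pool_shape = [7, 7, 256]
hidden_size = 768


def expected_output_shape(batch_size, num_text_tokens):
    """Shape of last_hidden_state, assuming visual tokens are appended to text tokens."""
    num_visual_tokens = image_feature_pool_shape[0] * image_feature_pool_shape[1]
    return (batch_size, num_text_tokens + num_visual_tokens, hidden_size)


# Under this assumption, 342 positions correspond to 293 text tokens + 49 visual tokens.
print(expected_output_shape(1, 293))  # (1, 342, 768)

# With padding to the maximum text length of 512 tokens:
print(expected_output_shape(1, 512))  # (1, 561, 768)
```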

LayoutLMv2ForSequenceClassification

class transformers.LayoutLMv2ForSequenceClassification

( config )

Parameters

LayoutLMv2 Model with a sequence classification head on top (a linear layer on top of the concatenation of the final hidden state of the [CLS] token, average-pooled initial visual embeddings and average-pooled final visual embeddings), e.g. for document image classification tasks such as the RVL-CDIP dataset.

This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
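To illustrate the input of the classification head described above, the sketch below builds the concatenated feature vector with plain Python lists. It is a toy illustration with a hidden size of 4 and two visual tokens, not the model's implementation; the function names and the concatenation order (taken from the order in which the description lists the three parts) are assumptions.

```python
def mean_pool(vectors):
    """Average a list of equal-length vectors element-wise."""
    count = len(vectors)
    return [sum(column) / count for column in zip(*vectors)]


def classification_features(cls_state, initial_visual, final_visual):
    """Concatenate [CLS] state, pooled initial and pooled final visual embeddings."""
    return cls_state + mean_pool(initial_visual) + mean_pool(final_visual)


hidden_size = 4
cls_state = [1.0, 2.0, 3.0, 4.0]
initial_visual = [[0.0, 0.0, 0.0, 0.0], [2.0, 2.0, 2.0, 2.0]]
final_visual = [[1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0]]

features = classification_features(cls_state, initial_visual, final_visual)
print(len(features))  # 12, i.e. 3 * hidden_size
print(features)
```

The linear layer therefore sees a vector three times the hidden size, which would be 3 * 768 = 2304 values with the default configuration.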

forward

( input_ids: typing.Optional[torch.LongTensor] = None bbox: typing.Optional[torch.LongTensor] = None image: typing.Optional[torch.FloatTensor] = None attention_mask: typing.Optional[torch.FloatTensor] = None token_type_ids: typing.Optional[torch.LongTensor] = None position_ids: typing.Optional[torch.LongTensor] = None head_mask: typing.Optional[torch.FloatTensor] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.SequenceClassifierOutput or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.SequenceClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (LayoutLMv2Config) and inputs.

The LayoutLMv2ForSequenceClassification forward method, overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Example:

from transformers import AutoProcessor, LayoutLMv2ForSequenceClassification, set_seed
from PIL import Image
import torch
from datasets import load_dataset

set_seed(0)

dataset = load_dataset("aharley/rvl_cdip", split="train", streaming=True, trust_remote_code=True)
data = next(iter(dataset))
image = data["image"].convert("RGB")

processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
model = LayoutLMv2ForSequenceClassification.from_pretrained(
    "microsoft/layoutlmv2-base-uncased", num_labels=dataset.info.features["label"].num_classes
)

encoding = processor(image, return_tensors="pt")
sequence_label = torch.tensor([data["label"]])

outputs = model(**encoding, labels=sequence_label)

loss, logits = outputs.loss, outputs.logits
predicted_idx = logits.argmax(dim=-1).item()
# The class name below is looked up with the hard-coded index 4, not with predicted_idx
predicted_answer = dataset.info.features["label"].names[4]
predicted_idx, predicted_answer
(7, 'advertisement')

LayoutLMv2ForTokenClassification

class transformers.LayoutLMv2ForTokenClassification

( config )

Parameters

LayoutLMv2 Model with a token classification head on top (a linear layer on top of the text part of the hidden states) e.g. for sequence labeling (information extraction) tasks such as FUNSD, SROIE, CORD and Kleister-NDA.

This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.LongTensor] = None bbox: typing.Optional[torch.LongTensor] = None image: typing.Optional[torch.FloatTensor] = None attention_mask: typing.Optional[torch.FloatTensor] = None token_type_ids: typing.Optional[torch.LongTensor] = None position_ids: typing.Optional[torch.LongTensor] = None head_mask: typing.Optional[torch.FloatTensor] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.TokenClassifierOutput or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.TokenClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (LayoutLMv2Config) and inputs.

The LayoutLMv2ForTokenClassification forward method, overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Example:

from transformers import AutoProcessor, LayoutLMv2ForTokenClassification, set_seed
from PIL import Image
from datasets import load_dataset

set_seed(0)

datasets = load_dataset("nielsr/funsd", split="test", trust_remote_code=True)
labels = datasets.features["ner_tags"].feature.names
id2label = {v: k for v, k in enumerate(labels)}

processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
model = LayoutLMv2ForTokenClassification.from_pretrained(
    "microsoft/layoutlmv2-base-uncased", num_labels=len(labels)
)

data = datasets[0]
image = Image.open(data["image_path"]).convert("RGB")
words = data["words"]
boxes = data["bboxes"]
word_labels = data["ner_tags"]
encoding = processor(
    image,
    words,
    boxes=boxes,
    word_labels=word_labels,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)

outputs = model(**encoding)
logits, loss = outputs.logits, outputs.loss

predicted_token_class_ids = logits.argmax(-1)
predicted_tokens_classes = [id2label[t.item()] for t in predicted_token_class_ids[0]]
predicted_tokens_classes[:5]
['I-HEADER', 'I-HEADER', 'I-QUESTION', 'I-HEADER', 'I-QUESTION']
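The predictions above are per token, including special tokens, padding and non-first wordpieces. A common post-processing step, sketched here with plain lists rather than tensors, keeps only positions whose label is not -100, which under the default only_label_first_subword=True leaves one prediction per word. The function name and the tiny label map are hypothetical.

```python
def word_level_predictions(token_predictions, token_labels, id2label, ignore_index=-100):
    """Keep the prediction at every position whose label is not ignore_index."""
    return [
        id2label[prediction]
        for prediction, label in zip(token_predictions, token_labels)
        if label != ignore_index
    ]


# Hypothetical label map and per-token values for a 6-token sequence.
id2label = {0: "O", 1: "B-HEADER", 2: "I-HEADER"}
token_predictions = [0, 1, 2, 2, 0, 0]
token_labels = [-100, 1, -100, 2, 0, -100]  # -100 marks [CLS], extra wordpieces, [SEP]

predictions = word_level_predictions(token_predictions, token_labels, id2label)
print(predictions)  # ['B-HEADER', 'I-HEADER', 'O']
```

With real model outputs, token_predictions would be predicted_token_class_ids[0].tolist() and token_labels would be encoding["labels"][0].tolist().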

LayoutLMv2ForQuestionAnswering

class transformers.LayoutLMv2ForQuestionAnswering

( config has_visual_segment_embedding = True )

Parameters

LayoutLMv2 Model with a span classification head on top for extractive question-answering tasks such as DocVQA (a linear layer on top of the text part of the hidden-states output to compute span start logits and span end logits).

This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.LongTensor] = None bbox: typing.Optional[torch.LongTensor] = None image: typing.Optional[torch.FloatTensor] = None attention_mask: typing.Optional[torch.FloatTensor] = None token_type_ids: typing.Optional[torch.LongTensor] = None position_ids: typing.Optional[torch.LongTensor] = None head_mask: typing.Optional[torch.FloatTensor] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None start_positions: typing.Optional[torch.LongTensor] = None end_positions: typing.Optional[torch.LongTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.QuestionAnsweringModelOutput or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.QuestionAnsweringModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (LayoutLMv2Config) and inputs.

The LayoutLMv2ForQuestionAnswering forward method, overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Example:

In the example below, we give the LayoutLMv2 model an image (of text) and ask it a question. It returns a prediction of what it thinks the answer is (the span of the answer within the text parsed from the image).

from transformers import AutoProcessor, LayoutLMv2ForQuestionAnswering, set_seed
import torch
from PIL import Image
from datasets import load_dataset

set_seed(0)
processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
model = LayoutLMv2ForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased")

dataset = load_dataset("hf-internal-testing/fixtures_docvqa", trust_remote_code=True)
image_path = dataset["test"][0]["file"]
image = Image.open(image_path).convert("RGB")
question = "When is coffee break?"
encoding = processor(image, question, return_tensors="pt")

outputs = model(**encoding)
predicted_start_idx = outputs.start_logits.argmax(-1).item()
predicted_end_idx = outputs.end_logits.argmax(-1).item()
predicted_start_idx, predicted_end_idx
(30, 191)

predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
predicted_answer = processor.tokenizer.decode(predicted_answer_tokens)
predicted_answer
'44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president “ introductory remarks ” lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from'

target_start_index = torch.tensor([7])
target_end_index = torch.tensor([14])
outputs = model(**encoding, start_positions=target_start_index, end_positions=target_end_index)
predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
predicted_answer_span_start, predicted_answer_span_end
(30, 191)
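The example above takes the argmax of the start and end logits independently, which can yield a very long span (as with (30, 191)) or even an end index that precedes the start. A common alternative, which is not part of this page's example, is to search for the best valid pair jointly. The sketch below does this with plain lists; best_span and max_answer_len are hypothetical names, and the logits are toy values.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return the (start, end) pair with the highest summed logit.

    Only spans with start <= end and at most max_answer_len tokens are considered.
    """
    best_pair, best_score = None, float("-inf")
    for start, start_logit in enumerate(start_logits):
        last_end = min(start + max_answer_len, len(end_logits))
        for end in range(start, last_end):
            score = start_logit + end_logits[end]
            if score > best_score:
                best_pair, best_score = (start, end), score
    return best_pair


# Toy logits for a 4-token sequence.
start_logits = [0.1, 2.0, 0.3, 0.2]
end_logits = [3.0, 0.1, 1.5, 0.4]

# Independent argmax gives an invalid span: the end comes before the start.
naive = (
    max(range(len(start_logits)), key=start_logits.__getitem__),
    max(range(len(end_logits)), key=end_logits.__getitem__),
)
print(naive)  # (1, 0)

print(best_span(start_logits, end_logits))  # (1, 2)
```

With real outputs, the inputs would be outputs.start_logits[0].tolist() and outputs.end_logits[0].tolist(). The nested loop is quadratic in the sequence length, which is acceptable for a few hundred tokens.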
