Donut

Overview

The Donut model was proposed in OCR-free Document Understanding Transformer by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park. Donut consists of an image Transformer encoder and an autoregressive text Transformer decoder to perform document understanding tasks such as document image classification, form understanding and visual question answering.

The abstract from the paper is the following:

Understanding document images (e.g., invoices) is a core but challenging task since it requires complex functions such as reading text and a holistic understanding of the document. Current Visual Document Understanding (VDU) methods outsource the task of reading text to off-the-shelf Optical Character Recognition (OCR) engines and focus on the understanding task with the OCR outputs. Although such OCR-based approaches have shown promising performance, they suffer from 1) high computational costs for using OCR; 2) inflexibility of OCR models on languages or types of document; 3) OCR error propagation to the subsequent process. To address these issues, in this paper, we introduce a novel OCR-free VDU model named Donut, which stands for Document understanding transformer. As the first step in OCR-free VDU research, we propose a simple architecture (i.e., Transformer) with a pre-training objective (i.e., cross-entropy loss). Donut is conceptually simple yet effective. Through extensive experiments and analyses, we show a simple OCR-free VDU model, Donut, achieves state-of-the-art performances on various VDU tasks in terms of both speed and accuracy. In addition, we offer a synthetic data generator that helps the model pre-training to be flexible in various languages and domains.

Figure: Donut high-level overview. Taken from the original paper.

This model was contributed by nielsr. The original code can be found here.

Usage tips

Inference examples

Donut’s VisionEncoderDecoder model accepts images as input and makes use of generate() to autoregressively generate text given the input image.

The DonutImageProcessor class is responsible for preprocessing the input image, and XLMRobertaTokenizer/XLMRobertaTokenizerFast decodes the generated target tokens to the target string. The DonutProcessor wraps DonutImageProcessor and XLMRobertaTokenizer/XLMRobertaTokenizerFast into a single instance to both extract the input features and decode the predicted token ids.

Step-by-step Document Image Classification

import re

from transformers import DonutProcessor, VisionEncoderDecoderModel
from datasets import load_dataset
import torch

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# load a document image
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[1]["image"]

# prepare decoder inputs: the task start token for this checkpoint
task_prompt = "<s_rvlcdip>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

pixel_values = processor(image, return_tensors="pt").pixel_values

outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_position_embeddings,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

sequence = processor.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove the first task start token
print(processor.token2json(sequence))
# {'class': 'advertisement'}
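The post-processing at the end of the example can be tried in isolation, without loading a model. The sketch below applies the same two steps (drop the end-of-sequence and padding tokens, then remove the first tag) to a hard-coded string. The token strings "</s>" and "<pad>", the tag <s_task>, and the helper name strip_special_tokens are assumptions made for illustration; in real code, read the tokens from processor.tokenizer.

```python
import re


def strip_special_tokens(sequence: str, eos_token: str = "</s>", pad_token: str = "<pad>") -> str:
    """Remove EOS/padding tokens, then drop the first tag (the task start token)."""
    sequence = sequence.replace(eos_token, "").replace(pad_token, "")
    # count=1 removes only the leading task tag and keeps the field tags intact
    return re.sub(r"<.*?>", "", sequence, count=1).strip()


# Hypothetical decoded output, not produced by an actual model run
raw = "<s_task><s_class>letter</s_class></s><pad><pad>"
cleaned = strip_special_tokens(raw)
print(cleaned)
```

Only the first tag is removed because the remaining tags carry the structure that token2json() turns into keys.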

Step-by-step Document Parsing

import re

from transformers import DonutProcessor, VisionEncoderDecoderModel
from datasets import load_dataset
import torch

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# load a document image
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[2]["image"]

# prepare decoder inputs: the task start token for this checkpoint
task_prompt = "<s_cord-v2>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

pixel_values = processor(image, return_tensors="pt").pixel_values

outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_position_embeddings,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

sequence = processor.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove the first task start token
print(processor.token2json(sequence))
# {'menu': {'nm': 'CINNAMON SUGAR', 'unitprice': '17,000', 'cnt': '1 x', 'price': '17,000'}, 'sub_total': {'subtotal_price': '17,000'}, 'total': {'total_price': '17,000', 'cashprice': '20,000', 'changeprice': '3,000'}}
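To make the nested output above easier to read, here is a simplified sketch of how a tagged sequence of the form <s_key>value</s_key> can be turned into a nested dictionary. This is not the library's token2json() implementation: the function name tags_to_dict is hypothetical, the input string is hand-written, and list separators and repeated keys are deliberately not handled.

```python
import re

# Matches one '<s_key>...</s_key>' span; the backreference pairs each open tag
# with the closing tag of the same name.
TAG_PATTERN = re.compile(r"<s_(?P<key>[^>]+)>(?P<body>.*?)</s_(?P=key)>", re.DOTALL)


def tags_to_dict(sequence: str) -> dict:
    """Convert '<s_key>value</s_key>' spans into a nested dict (simplified)."""
    result = {}
    for match in TAG_PATTERN.finditer(sequence):
        body = match.group("body").strip()
        # Recurse when the body itself contains tagged fields
        result[match.group("key")] = tags_to_dict(body) if "<s_" in body else body
    return result


# Hand-written input for illustration, not actual model output
sequence = (
    "<s_menu><s_nm>CINNAMON SUGAR</s_nm><s_price>17,000</s_price></s_menu>"
    "<s_total><s_total_price>17,000</s_total_price></s_total>"
)
parsed = tags_to_dict(sequence)
print(parsed)
```

In practice, prefer processor.token2json(), which also covers the cases this sketch leaves out.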

Step-by-step Document Visual Question Answering (DocVQA)

import re

from transformers import DonutProcessor, VisionEncoderDecoderModel
from datasets import load_dataset
import torch

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# load a document image from the DocVQA dataset
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]

# prepare decoder inputs: the question is embedded in the task prompt
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
question = "When is the coffee break?"
prompt = task_prompt.replace("{user_input}", question)
decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids

pixel_values = processor(image, return_tensors="pt").pixel_values

outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_position_embeddings,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

sequence = processor.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove the first task start token
print(processor.token2json(sequence))
# {'question': 'When is the coffee break?', 'answer': '11-14 to 11:39 a.m.'}
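For question answering, the question is part of the decoder prompt, and the prompt ends with an open answer tag so that generation continues with the answer. The sketch below shows only the string handling. The template with the <s_docvqa>, <s_question> and <s_answer> tags is assumed to match the DocVQA checkpoint's prompt format, and the helper name build_docvqa_prompt is hypothetical.

```python
# Assumed prompt format for the DocVQA checkpoint; verify it against the model card.
DOCVQA_TEMPLATE = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"


def build_docvqa_prompt(question: str, template: str = DOCVQA_TEMPLATE) -> str:
    """Insert the user's question into the task prompt."""
    # str.replace is used instead of str.format so braces in the question are left alone
    return template.replace("{user_input}", question)


prompt = build_docvqa_prompt("When is the coffee break?")
print(prompt)
```

The resulting string is what gets tokenized with add_special_tokens=False and passed to generate() as decoder_input_ids.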

See the model hub to look for Donut checkpoints.

Training

For training, refer to the tutorial notebooks.

DonutSwinConfig

class transformers.DonutSwinConfig

(
    image_size = 224,
    patch_size = 4,
    num_channels = 3,
    embed_dim = 96,
    depths = [2, 2, 6, 2],
    num_heads = [3, 6, 12, 24],
    window_size = 7,
    mlp_ratio = 4.0,
    qkv_bias = True,
    hidden_dropout_prob = 0.0,
    attention_probs_dropout_prob = 0.0,
    drop_path_rate = 0.1,
    hidden_act = 'gelu',
    use_absolute_embeddings = False,
    initializer_range = 0.02,
    layer_norm_eps = 1e-05,
    **kwargs
)

This is the configuration class to store the configuration of a DonutSwinModel. It is used to instantiate a Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Donut naver-clova-ix/donut-base architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

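from transformers import DonutSwinConfig, DonutSwinModel

# Initialize a configuration with the default values
configuration = DonutSwinConfig()

# Initialize a model (with random weights) from that configuration
model = DonutSwinModel(configuration)

# Access the model configuration
configuration = model.config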
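The default values determine the shape of the encoder output. The arithmetic below is a sketch under the usual Swin assumptions, which are not stated in this page: every stage except the last halves the patch grid and doubles the channel width. The variable names are illustrative only.

```python
# Default values copied from the DonutSwinConfig signature
image_size = 224
patch_size = 4
embed_dim = 96
depths = [2, 2, 6, 2]

num_stages = len(depths)
num_downsamples = num_stages - 1  # assumption: no downsampling after the last stage

patch_grid = image_size // patch_size            # patches per side after patch embedding
final_grid = patch_grid // (2 ** num_downsamples)  # patches per side after the last stage
final_tokens = final_grid ** 2                   # sequence length of the last hidden state
final_hidden = embed_dim * (2 ** num_downsamples)  # channel width of the last hidden state

print(patch_grid, final_grid, final_tokens, final_hidden)
```

Under these assumptions a 224 x 224 input yields a last hidden state with 49 tokens of width 768.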

DonutImageProcessor

class transformers.DonutImageProcessor

(
    do_resize: bool = True,
    size: typing.Dict[str, int] = None,
    resample: Resampling = <Resampling.BILINEAR: 2>,
    do_thumbnail: bool = True,
    do_align_long_axis: bool = False,
    do_pad: bool = True,
    do_rescale: bool = True,
    rescale_factor: typing.Union[int, float] = 0.00392156862745098,
    do_normalize: bool = True,
    image_mean: typing.Union[float, typing.List[float], NoneType] = None,
    image_std: typing.Union[float, typing.List[float], NoneType] = None,
    **kwargs
)

Constructs a Donut image processor.
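The default rescale_factor of 0.00392156862745098 is 1/255, which maps 8-bit pixel values into the range [0, 1] before normalization. The sketch below applies rescaling and normalization to single pixel values. The mean and standard deviation of 0.5 are assumptions chosen for illustration (the signature defaults are None, and the library substitutes its own values), and the function name is hypothetical.

```python
RESCALE_FACTOR = 0.00392156862745098  # default from the signature, equal to 1 / 255
IMAGE_MEAN = 0.5  # assumption for illustration
IMAGE_STD = 0.5   # assumption for illustration


def rescale_and_normalize(pixel: int) -> float:
    """Map an 8-bit pixel value to [0, 1], then standardize it."""
    scaled = pixel * RESCALE_FACTOR
    return (scaled - IMAGE_MEAN) / IMAGE_STD


values = [rescale_and_normalize(p) for p in (0, 128, 255)]
print(values)
```

With these assumed statistics, black pixels map to about -1 and white pixels to about 1.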

preprocess

(
    images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']],
    do_resize: typing.Optional[bool] = None,
    size: typing.Dict[str, int] = None,
    resample: Resampling = None,
    do_thumbnail: typing.Optional[bool] = None,
    do_align_long_axis: typing.Optional[bool] = None,
    do_pad: typing.Optional[bool] = None,
    random_padding: bool = False,
    do_rescale: typing.Optional[bool] = None,
    rescale_factor: typing.Optional[float] = None,
    do_normalize: typing.Optional[bool] = None,
    image_mean: typing.Union[float, typing.List[float], NoneType] = None,
    image_std: typing.Union[float, typing.List[float], NoneType] = None,
    return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None,
    data_format: typing.Optional[transformers.image_utils.ChannelDimension] = <ChannelDimension.FIRST: 'channels_first'>,
    input_data_format: typing.Union[transformers.image_utils.ChannelDimension, str, NoneType] = None
)

Preprocess an image or batch of images.
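The do_thumbnail and do_pad flags suggest a two-step geometry: shrink the image so it fits inside the target size while keeping its aspect ratio, then pad it up to the target size. The sketch below shows that arithmetic with centered padding. It is a simplified approximation, not the library's exact algorithm, and both function names are hypothetical.

```python
def thumbnail_size(height: int, width: int, target_height: int, target_width: int) -> tuple:
    """Largest size that fits inside the target, keeping aspect ratio, never upscaling."""
    scale = min(target_height / height, target_width / width, 1.0)
    return max(1, int(height * scale)), max(1, int(width * scale))


def centered_padding(height: int, width: int, target_height: int, target_width: int) -> tuple:
    """Padding (top, bottom, left, right) that centers the image in the target canvas."""
    pad_top = (target_height - height) // 2
    pad_left = (target_width - width) // 2
    pad_bottom = target_height - height - pad_top
    pad_right = target_width - width - pad_left
    return pad_top, pad_bottom, pad_left, pad_right


# An 800 x 400 (height x width) page placed on a 400 x 300 canvas
new_height, new_width = thumbnail_size(800, 400, 400, 300)
padding = centered_padding(new_height, new_width, 400, 300)
print((new_height, new_width), padding)
```

The random_padding argument in the signature above indicates that padding does not have to be centered; this sketch covers only the centered case.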

DonutFeatureExtractor

Preprocess an image or a batch of images.

DonutProcessor

class transformers.DonutProcessor

(
    image_processor = None,
    tokenizer = None,
    **kwargs
)

Constructs a Donut processor which wraps a Donut image processor and an XLMRoBERTa tokenizer into a single processor.

DonutProcessor offers all the functionalities of DonutImageProcessor and XLMRobertaTokenizer/XLMRobertaTokenizerFast. See the __call__() and decode() methods for more information.

__call__

(
    images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']] = None,
    text: typing.Union[str, typing.List[str], NoneType] = None,
    audio = None,
    videos = None,
    **kwargs: typing_extensions.Unpack[transformers.models.donut.processing_donut.DonutProcessorKwargs]
)

When used in normal mode, this method forwards all its arguments to AutoImageProcessor’s __call__() and returns its output. If used in the context as_target_processor(), this method forwards all its arguments to DonutTokenizer’s __call__(). Please refer to the docstrings of the above two methods for more information.

from_pretrained

(
    pretrained_model_name_or_path: typing.Union[str, os.PathLike],
    cache_dir: typing.Union[str, os.PathLike, NoneType] = None,
    force_download: bool = False,
    local_files_only: bool = False,
    token: typing.Union[str, bool, NoneType] = None,
    revision: str = 'main',
    **kwargs
)

Instantiate a processor associated with a pretrained model.

This class method simply calls the feature extractor from_pretrained(), the image processor ImageProcessingMixin and the tokenizer PreTrainedTokenizer.from_pretrained methods. Please refer to the docstrings of the methods above for more information.

save_pretrained

(
    save_directory,
    push_to_hub: bool = False,
    **kwargs
)

Saves the attributes of this processor (feature extractor, tokenizer…) in the specified directory so that it can be reloaded using the from_pretrained() method.

This class method simply calls the feature extractor's save_pretrained() and the tokenizer's save_pretrained(). Please refer to the docstrings of the methods above for more information.

batch_decode

This method forwards all its arguments to DonutTokenizer’s batch_decode(). Please refer to the docstring of this method for more information.

decode

This method forwards all its arguments to DonutTokenizer’s decode(). Please refer to the docstring of this method for more information.

DonutSwinModel

class transformers.DonutSwinModel

(
    config,
    add_pooling_layer = True,
    use_mask_token = False
)

The bare Donut Swin Model transformer outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

(
    pixel_values: typing.Optional[torch.FloatTensor] = None,
    bool_masked_pos: typing.Optional[torch.BoolTensor] = None,
    head_mask: typing.Optional[torch.FloatTensor] = None,
    output_attentions: typing.Optional[bool] = None,
    output_hidden_states: typing.Optional[bool] = None,
    interpolate_pos_encoding: bool = False,
    return_dict: typing.Optional[bool] = None
) → transformers.models.donut.modeling_donut_swin.DonutSwinModelOutput or tuple(torch.FloatTensor)

Returns

transformers.models.donut.modeling_donut_swin.DonutSwinModelOutput or tuple(torch.FloatTensor)

A transformers.models.donut.modeling_donut_swin.DonutSwinModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (DonutSwinConfig) and inputs.

The DonutSwinModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
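The difference between calling the instance and calling forward() directly can be illustrated with a toy class. ToyModule below is a hypothetical stand-in written for this illustration, not PyTorch's torch.nn.Module. It only mimics the idea that __call__ wraps forward() with extra steps, which are skipped when forward() is invoked directly.

```python
class ToyModule:
    """Toy stand-in for a module whose __call__ wraps forward() with hooks."""

    def __init__(self):
        self.log = []

    def forward(self, x):
        # The "recipe" of the forward pass lives here
        return x * 2

    def __call__(self, x):
        self.log.append("pre-hook")   # runs only when the instance is called
        output = self.forward(x)
        self.log.append("post-hook")  # runs only when the instance is called
        return output


module = ToyModule()
via_call = module(3)             # hooks run
via_forward = module.forward(3)  # hooks are silently skipped
print(via_call, via_forward, module.log)
```

Both calls return the same value, but only the first one triggers the surrounding steps, which is why model(**inputs) is preferred over model.forward(**inputs).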

Example:

from transformers import AutoImageProcessor, DonutSwinModel
import torch
from datasets import load_dataset

dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
image = dataset["test"]["image"][0]

# from_pretrained expects a Hub repository id (or a local path), not a full URL
image_processor = AutoImageProcessor.from_pretrained("naver-clova-ix/donut-base")
model = DonutSwinModel.from_pretrained("naver-clova-ix/donut-base")

inputs = image_processor(image, return_tensors="pt")

# run inference without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state
list(last_hidden_states.shape)
# [1, 49, 768]
