BertGeneration

PyTorch

Overview

The BertGeneration model is a BERT model that can be leveraged for sequence-to-sequence tasks using EncoderDecoderModel, as proposed in Leveraging Pre-trained Checkpoints for Sequence Generation Tasks by Sascha Rothe, Shashi Narayan, and Aliaksei Severyn.

The abstract from the paper is the following:

Unsupervised pretraining of large neural models has recently revolutionized Natural Language Processing. By warm-starting from the publicly released checkpoints, NLP practitioners have pushed the state-of-the-art on multiple benchmarks while saving significant amounts of compute time. So far the focus has been mainly on the Natural Language Understanding tasks. In this paper, we demonstrate the efficacy of pre-trained checkpoints for Sequence Generation. We developed a Transformer-based sequence-to-sequence model that is compatible with publicly available pre-trained BERT, GPT-2 and RoBERTa checkpoints and conducted an extensive empirical study on the utility of initializing our model, both encoder and decoder, with these checkpoints. Our models result in new state-of-the-art results on Machine Translation, Text Summarization, Sentence Splitting, and Sentence Fusion.

This model was contributed by patrickvonplaten. The original code can be found here.

Usage examples and tips

The model can be used in combination with the EncoderDecoderModel to leverage two pretrained BERT checkpoints for subsequent fine-tuning:

from transformers import BertGenerationDecoder, BertGenerationEncoder, BertTokenizer, EncoderDecoderModel

# use BERT's [CLS] token (id 101) as BOS token and [SEP] token (id 102) as EOS token
encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-large-uncased", bos_token_id=101, eos_token_id=102)

# add cross-attention layers so the decoder can attend to the encoder outputs
decoder = BertGenerationDecoder.from_pretrained(
    "google-bert/bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)

tokenizer = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")

input_ids = tokenizer(
    "This is a long article to summarize", add_special_tokens=False, return_tensors="pt"
).input_ids
labels = tokenizer("This is a short summary", return_tensors="pt").input_ids

# compute the sequence-to-sequence LM loss and backpropagate
loss = bert2bert(input_ids=input_ids, decoder_input_ids=labels, labels=labels).loss
loss.backward()
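After fine-tuning, the same bert2bert model can be used for generation. A minimal sketch, assuming the model and tokenizer from above; the decoder start and pad token ids still need to be set on the combined config before calling generate:

# generation sketch: the EncoderDecoderModel needs to know which token starts decoding
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id

input_ids = tokenizer(
    "This is a long article to summarize", add_special_tokens=False, return_tensors="pt"
).input_ids
generated = bert2bert.generate(input_ids, max_length=32)
print(tokenizer.decode(generated[0], skip_special_tokens=True))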

Pretrained EncoderDecoderModel checkpoints are also directly available on the model hub, e.g.:

from transformers import AutoTokenizer, EncoderDecoderModel

# instantiate the sentence fusion model and its tokenizer
sentence_fuser = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_discofuse")
tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_discofuse")

input_ids = tokenizer(
    "This is the first sentence. This is the second sentence.", add_special_tokens=False, return_tensors="pt"
).input_ids

outputs = sentence_fuser.generate(input_ids)

print(tokenizer.decode(outputs[0]))
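For several inputs at once, padding plus an attention mask keeps generation aligned. The sketch below builds on the sentence-fusion example above; the second sentence pair is just an illustrative input:

# batched sketch, reusing sentence_fuser and tokenizer from above
batch = tokenizer(
    [
        "This is the first sentence. This is the second sentence.",
        "The meeting ran late. Everyone was tired.",
    ],
    add_special_tokens=False,
    padding=True,
    return_tensors="pt",
)
outputs = sentence_fuser.generate(batch.input_ids, attention_mask=batch.attention_mask)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))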

Tips:

- BertGenerationEncoder and BertGenerationDecoder should be used in combination with EncoderDecoderModel.
- For summarization, sentence splitting, sentence fusion and translation, no special tokens are required for the input. Therefore, no EOS token should be added to the end of the input.

BertGenerationConfig

class transformers.BertGenerationConfig


( vocab_size = 50358 hidden_size = 1024 num_hidden_layers = 24 num_attention_heads = 16 intermediate_size = 4096 hidden_act = 'gelu' hidden_dropout_prob = 0.1 attention_probs_dropout_prob = 0.1 max_position_embeddings = 512 initializer_range = 0.02 layer_norm_eps = 1e-12 pad_token_id = 0 bos_token_id = 2 eos_token_id = 1 position_embedding_type = 'absolute' use_cache = True **kwargs )


This is the configuration class to store the configuration of a BertGenerationPreTrainedModel. It is used to instantiate a BertGeneration model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of the BertGeneration google/bert_for_seq_generation_L-24_bbc_encoder architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Examples:

from transformers import BertGenerationConfig, BertGenerationEncoder

# Initializing a BertGeneration config
configuration = BertGenerationConfig()

# Initializing a model (with random weights) from the config
model = BertGenerationEncoder(configuration)

# Accessing the model configuration
configuration = model.config
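The defaults above correspond to the 24-layer checkpoint; a configuration can also be shrunk for quick experiments, as in the sketch below (the sizes are arbitrary examples, not a released checkpoint):

from transformers import BertGenerationConfig, BertGenerationDecoder

# illustrative small configuration; hidden_size must stay divisible by num_attention_heads
small_config = BertGenerationConfig(
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=1024,
    is_decoder=True,           # required to use the model as a decoder
    add_cross_attention=True,  # required when pairing it with an encoder in an EncoderDecoderModel
)
decoder = BertGenerationDecoder(small_config)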

BertGenerationTokenizer

class transformers.BertGenerationTokenizer


( vocab_file bos_token = '&lt;s&gt;' eos_token = '&lt;/s&gt;' unk_token = '&lt;unk&gt;' pad_token = '&lt;pad&gt;' sep_token = '&lt;::::&gt;' sp_model_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None **kwargs )


Construct a BertGeneration tokenizer. Based on SentencePiece.

This tokenizer inherits from PreTrainedTokenizer which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

save_vocabulary


( save_directory: str filename_prefix: typing.Optional[str] = None )
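A short usage sketch, assuming the sentencepiece package is installed and using the checkpoint referenced elsewhere on this page:

import os

from transformers import BertGenerationTokenizer

tokenizer = BertGenerationTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
input_ids = tokenizer("This is a long article to summarize", return_tensors="pt").input_ids

# save_vocabulary expects an existing directory and writes the SentencePiece model file into it
os.makedirs("saved_tokenizer", exist_ok=True)
tokenizer.save_vocabulary("saved_tokenizer")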

BertGenerationEncoder

class transformers.BertGenerationEncoder


( config )


The bare BertGeneration model transformer outputting raw hidden-states without any specific head on top.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_ids: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None encoder_hidden_states: typing.Optional[torch.Tensor] = None encoder_attention_mask: typing.Optional[torch.Tensor] = None past_key_values: typing.Optional[typing.Tuple[typing.Tuple[torch.FloatTensor]]] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None **kwargs ) → transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions or tuple(torch.FloatTensor)


Returns a transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BertGenerationConfig) and inputs.

The BertGenerationEncoder forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
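A minimal sketch of calling the encoder on its own; token_type_ids are dropped because the forward signature above does not accept them:

import torch

from transformers import AutoTokenizer, BertGenerationEncoder

tokenizer = AutoTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
encoder = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")

inputs = tokenizer("This is a long article to summarize", return_token_type_ids=False, return_tensors="pt")
with torch.no_grad():
    outputs = encoder(**inputs)

last_hidden_state = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)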

BertGenerationDecoder

class transformers.BertGenerationDecoder


( config )


BertGeneration Model with a language modeling head on top for CLM fine-tuning.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_ids: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None encoder_hidden_states: typing.Optional[torch.Tensor] = None encoder_attention_mask: typing.Optional[torch.Tensor] = None labels: typing.Optional[torch.Tensor] = None past_key_values: typing.Optional[typing.Tuple[typing.Tuple[torch.FloatTensor]]] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None **kwargs ) → transformers.modeling_outputs.CausalLMOutputWithCrossAttentions or tuple(torch.FloatTensor)


Returns a transformers.modeling_outputs.CausalLMOutputWithCrossAttentions or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BertGenerationConfig) and inputs.

The BertGenerationDecoder forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

import torch

from transformers import AutoTokenizer, BertGenerationConfig, BertGenerationDecoder

tokenizer = AutoTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
config = BertGenerationConfig.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
config.is_decoder = True
model = BertGenerationDecoder.from_pretrained(
    "google/bert_for_seq_generation_L-24_bbc_encoder", config=config
)

inputs = tokenizer("Hello, my dog is cute", return_token_type_ids=False, return_tensors="pt")
outputs = model(**inputs)

prediction_logits = outputs.logits
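Since forward also accepts labels (see the signature above), the causal LM loss can be obtained directly; a small follow-up sketch reusing the inputs from the example:

# follow-up sketch: reuse the tokenized inputs as labels to get the causal LM loss
outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
loss.backward()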
