Splinter

PyTorch

Overview

The Splinter model was proposed in "Few-Shot Question Answering by Pretraining Span Selection" by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, and Omer Levy. Splinter is an encoder-only transformer (similar to BERT) pretrained with the recurring span selection task on a large corpus comprising Wikipedia and the Toronto Book Corpus.

The abstract from the paper is the following:

In several question answering benchmarks, pretrained models have reached human parity through fine-tuning on an order of 100,000 annotated questions and answers. We explore the more realistic few-shot setting, where only a few hundred training examples are available, and observe that standard models perform poorly, highlighting the discrepancy between current pretraining objectives and question answering. We propose a new pretraining scheme tailored for question answering: recurring span selection. Given a passage with multiple sets of recurring spans, we mask in each set all recurring spans but one, and ask the model to select the correct span in the passage for each masked span. Masked spans are replaced with a special token, viewed as a question representation, that is later used during fine-tuning to select the answer span. The resulting model obtains surprisingly good results on multiple benchmarks (e.g., 72.7 F1 on SQuAD with only 128 training examples), while maintaining competitive performance in the high-resource setting.

This model was contributed by yuvalkirstain and oriram. The original code can be found here.

Usage tips

Resources

SplinterConfig

class transformers.SplinterConfig

( vocab_size = 30522 hidden_size = 768 num_hidden_layers = 12 num_attention_heads = 12 intermediate_size = 3072 hidden_act = 'gelu' hidden_dropout_prob = 0.1 attention_probs_dropout_prob = 0.1 max_position_embeddings = 512 type_vocab_size = 2 initializer_range = 0.02 layer_norm_eps = 1e-12 use_cache = True pad_token_id = 0 question_token_id = 104 **kwargs )

Parameters

This is the configuration class to store the configuration of a SplinterModel. It is used to instantiate a Splinter model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of the Splinter tau/splinter-base architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

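from transformers import SplinterModel, SplinterConfig

# Initialize a configuration with the default (tau/splinter-base style) values
configuration = SplinterConfig()

# Initialize a model with random weights from that configuration
model = SplinterModel(configuration)

# Access the model configuration
configuration = model.config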
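To make the default arguments more concrete, the following is a minimal sketch of what they imply for the encoder's shape. It assumes a standard BERT-style layer layout (four attention projections and two feed-forward projections per layer, each with biases and a LayerNorm), which this page does not spell out. `EncoderShape`, `head_size` and `approx_encoder_params` are hypothetical helpers, not part of the library, and the count ignores any task head.

```python
from dataclasses import dataclass


@dataclass
class EncoderShape:
    """Hypothetical stand-in holding a few SplinterConfig defaults."""
    vocab_size: int = 30522
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    intermediate_size: int = 3072
    max_position_embeddings: int = 512
    type_vocab_size: int = 2


def head_size(cfg: EncoderShape) -> int:
    # Each attention head works on an equal slice of the hidden vector.
    if cfg.hidden_size % cfg.num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    return cfg.hidden_size // cfg.num_attention_heads


def approx_encoder_params(cfg: EncoderShape) -> int:
    h, i = cfg.hidden_size, cfg.intermediate_size
    # Word, position and token-type embeddings plus one LayerNorm (weight + bias).
    embeddings = (cfg.vocab_size + cfg.max_position_embeddings + cfg.type_vocab_size) * h + 2 * h
    # Query, key, value and output projections (with biases) plus LayerNorm.
    attention = 4 * (h * h + h) + 2 * h
    # Two feed-forward projections (with biases) plus LayerNorm.
    feed_forward = (h * i + i) + (i * h + h) + 2 * h
    return embeddings + cfg.num_hidden_layers * (attention + feed_forward)


shape = EncoderShape()
print(head_size(shape))
print(approx_encoder_params(shape))
```

Changing `hidden_size` without keeping it divisible by `num_attention_heads` raises an error in this sketch, which is the kind of constraint worth checking before instantiating a custom configuration.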

SplinterTokenizer

class transformers.SplinterTokenizer

( vocab_file do_lower_case = True do_basic_tokenize = True never_split = None unk_token = '[UNK]' sep_token = '[SEP]' pad_token = '[PAD]' cls_token = '[CLS]' mask_token = '[MASK]' question_token = '[QUESTION]' tokenize_chinese_chars = True strip_accents = None **kwargs )

Parameters

Construct a Splinter tokenizer. Based on WordPiece.

This tokenizer inherits from PreTrainedTokenizer which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

build_inputs_with_special_tokens

( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None ) → List[int]

Parameters

Returns

List of input IDs with the appropriate special tokens.

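Build model inputs from a pair of sequences for question answering tasks by concatenating and adding special tokens. A Splinter sequence has the following format:

single sequence: [CLS] X [SEP]

pair of sequences for question answering: [CLS] question_tokens [QUESTION] . [SEP] context_tokens [SEP]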
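A minimal pure-Python sketch of this concatenation, assuming the question-then-context layout `[CLS] question [QUESTION] . [SEP] context [SEP]`. The id 104 matches the `question_token_id` default shown in SplinterConfig; the other ids and the helper name `build_inputs` are illustrative assumptions that depend on the vocabulary file, and the sketch is not the library's implementation.

```python
from typing import List, Optional

# Illustrative ids only. 104 matches the question_token_id default;
# the other values are assumptions and depend on the vocabulary file.
CLS_ID, SEP_ID, QUESTION_ID, DOT_ID = 101, 102, 104, 119


def build_inputs(question_ids: List[int], context_ids: Optional[List[int]] = None) -> List[int]:
    if context_ids is None:
        # Single sequence: [CLS] X [SEP]
        return [CLS_ID] + question_ids + [SEP_ID]
    # Pair: [CLS] question [QUESTION] . [SEP] context [SEP]
    return [CLS_ID] + question_ids + [QUESTION_ID, DOT_ID, SEP_ID] + context_ids + [SEP_ID]


print(build_inputs([7, 8, 9], [20, 21]))
# prints [101, 7, 8, 9, 104, 119, 102, 20, 21, 102]
```

The `[QUESTION]` token sits right after the question so that its hidden state can serve as the question representation used to select the answer span.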

get_special_tokens_mask

( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None already_has_special_tokens: bool = False ) → List[int]

Parameters

Returns

A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.

Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer prepare_for_model method.
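The idea of the mask can be sketched for the simpler case where the special tokens are already present in the list. This is an illustration only: `special_tokens_mask` is a hypothetical helper and the ids are made up.

```python
from typing import Iterable, List


def special_tokens_mask(token_ids: List[int], special_ids: Iterable[int]) -> List[int]:
    special = set(special_ids)
    # 1 marks a special token, 0 marks an ordinary sequence token.
    return [1 if token_id in special else 0 for token_id in token_ids]


# Hypothetical ids: 101 = [CLS], 102 = [SEP]
print(special_tokens_mask([101, 7, 8, 102, 20, 102], special_ids={101, 102}))
# prints [1, 0, 0, 1, 0, 1]
```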

create_token_type_ids_from_sequences

( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None ) → List[int]

Parameters

Returns

The token type ids.

Create the token type IDs corresponding to the sequences passed. What are token type IDs?

Should be overridden in a subclass if the model has a special way of building those.
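As a rough illustration of what token type IDs look like, the sketch below follows the common BERT-style convention: 0 for every position of the first segment (its special tokens included) and 1 for every position of the second. Whether Splinter assigns the `[QUESTION]` token and the separators exactly this way is an assumption here; `segment_ids` is a hypothetical helper.

```python
from typing import List, Optional


def segment_ids(first_segment: List[int], second_segment: Optional[List[int]] = None) -> List[int]:
    # Every position of the first segment, special tokens included, gets type 0.
    ids = [0] * len(first_segment)
    if second_segment is not None:
        # Every position of the second segment gets type 1.
        ids += [1] * len(second_segment)
    return ids


# Hypothetical ids: [CLS] q q [QUESTION] . [SEP] | c c [SEP]
first = [101, 7, 8, 104, 119, 102]
second = [20, 21, 102]
print(segment_ids(first, second))
# prints [0, 0, 0, 0, 0, 0, 1, 1, 1]
```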

save_vocabulary

( save_directory: str filename_prefix: typing.Optional[str] = None )

SplinterTokenizerFast

class transformers.SplinterTokenizerFast

( vocab_file = None tokenizer_file = None do_lower_case = True unk_token = '[UNK]' sep_token = '[SEP]' pad_token = '[PAD]' cls_token = '[CLS]' mask_token = '[MASK]' question_token = '[QUESTION]' tokenize_chinese_chars = True strip_accents = None **kwargs )

Parameters

Construct a “fast” Splinter tokenizer (backed by HuggingFace’s tokenizers library). Based on WordPiece.

This tokenizer inherits from PreTrainedTokenizerFast which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

build_inputs_with_special_tokens

( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None ) → List[int]

Parameters

Returns

List of input IDs with the appropriate special tokens.

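Build model inputs from a pair of sequences for question answering tasks by concatenating and adding special tokens. A Splinter sequence has the following format:

single sequence: [CLS] X [SEP]

pair of sequences for question answering: [CLS] question_tokens [QUESTION] . [SEP] context_tokens [SEP]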

SplinterModel

class transformers.SplinterModel

( config )

Parameters

The bare Splinter Model outputting raw hidden-states without any specific head on top.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None token_type_ids: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None encoder_hidden_states: typing.Optional[torch.Tensor] = None encoder_attention_mask: typing.Optional[torch.Tensor] = None past_key_values: typing.Optional[typing.List[torch.FloatTensor]] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions or tuple(torch.FloatTensor)

Parameters

Returns

A transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SplinterConfig) and inputs.

The SplinterModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
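The difference between calling the instance and calling forward directly can be illustrated with a toy class. `ToyModule` is a hypothetical, heavily simplified stand-in for torch.nn.Module that only models forward hooks (with a simplified hook signature); it is a sketch of the mechanism, not PyTorch's implementation.

```python
from typing import Callable, List


class ToyModule:
    """Toy stand-in for torch.nn.Module, reduced to forward hooks."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable[[int, int], None]] = []

    def register_forward_hook(self, hook: Callable[[int, int], None]) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: int) -> int:
        # The "recipe" of the forward pass lives here.
        return x * 2

    def __call__(self, x: int) -> int:
        # Calling the instance runs forward and then the registered hooks.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(x, out)
        return out


seen = []
module = ToyModule()
module.register_forward_hook(lambda inp, out: seen.append((inp, out)))

module(3)          # goes through __call__, so the hook fires
module.forward(4)  # bypasses __call__, so the hook is skipped
print(seen)
# prints [(3, 6)]
```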

SplinterForQuestionAnswering

class transformers.SplinterForQuestionAnswering

( config )

Parameters

The Splinter transformer with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute span start logits and span end logits).

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None token_type_ids: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None start_positions: typing.Optional[torch.LongTensor] = None end_positions: typing.Optional[torch.LongTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None question_positions: typing.Optional[torch.LongTensor] = None ) → transformers.modeling_outputs.QuestionAnsweringModelOutput or tuple(torch.FloatTensor)

Parameters

Returns

A transformers.modeling_outputs.QuestionAnsweringModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SplinterConfig) and inputs.

The SplinterForQuestionAnswering forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, SplinterForQuestionAnswering
import torch

tokenizer = AutoTokenizer.from_pretrained("tau/splinter-base")
model = SplinterForQuestionAnswering.from_pretrained("tau/splinter-base")

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"

inputs = tokenizer(question, text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()

predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
...

target_start_index = torch.tensor([14])
target_end_index = torch.tensor([15])

outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
loss = outputs.loss
round(loss.item(), 2)
...
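The example above takes the argmax of the start and end logits independently, which can yield an end index that precedes the start index. One common alternative is to score start and end jointly and keep only valid spans. The sketch below works on plain Python lists; `best_span` and its `max_answer_len` parameter are hypothetical names, and this is not how the library's pipeline is implemented.

```python
from typing import List, Tuple


def best_span(start_logits: List[float], end_logits: List[float], max_answer_len: int = 30) -> Tuple[int, int]:
    best_score = float("-inf")
    best = (0, 0)
    for start, start_score in enumerate(start_logits):
        # Only consider ends at or after the start, within the length limit.
        last = min(len(end_logits), start + max_answer_len)
        for end in range(start, last):
            score = start_score + end_logits[end]
            if score > best_score:
                best_score, best = score, (start, end)
    return best


# Independent argmax would pick start=3 and end=1, which is not a valid span.
start_logits = [0.1, 0.2, 0.3, 5.0]
end_logits = [0.1, 4.0, 3.0, 0.5]
print(best_span(start_logits, end_logits))
# prints (3, 3)
```

With tensors, the logits could be converted first, for example with `outputs.start_logits[0].tolist()`.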

SplinterForPreTraining

class transformers.SplinterForPreTraining

( config )

Parameters

Splinter Model for the recurring span selection task as done during the pretraining. The difference to the QA task is that we do not have a question, but multiple question tokens that replace the occurrences of recurring spans instead.
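To illustrate the data side of recurring span selection as the abstract describes it, here is a simplified sketch: for each set of recurring spans, every occurrence but one is replaced by a `[QUESTION]` token, and each such token is paired with the position of the surviving occurrence. How spans are detected and which occurrence is kept are assumptions (the first occurrence is kept here, and the sets are given by hand); `mask_recurring_spans` is a hypothetical helper and spans are assumed not to overlap.

```python
from typing import Dict, List, Tuple

QUESTION = "[QUESTION]"
Span = Tuple[int, int]  # inclusive (start, end) token indices


def mask_recurring_spans(
    tokens: List[str], span_sets: List[List[Span]]
) -> Tuple[List[str], List[Tuple[int, int, int]]]:
    # Map the start of every masked occurrence to (end, set index).
    masked: Dict[int, Tuple[int, int]] = {}
    for set_idx, occurrences in enumerate(span_sets):
        for start, end in occurrences[1:]:  # keep the first occurrence visible
            masked[start] = (end, set_idx)

    new_tokens: List[str] = []
    old_to_new: Dict[int, int] = {}
    question_sets: List[Tuple[int, int]] = []  # (new [QUESTION] position, set index)
    i = 0
    while i < len(tokens):
        if i in masked:
            end, set_idx = masked[i]
            question_sets.append((len(new_tokens), set_idx))
            new_tokens.append(QUESTION)  # one token replaces the whole span
            i = end + 1
        else:
            old_to_new[i] = len(new_tokens)
            new_tokens.append(tokens[i])
            i += 1

    # Each [QUESTION] token must point at the surviving occurrence of its span.
    targets = []
    for q_pos, set_idx in question_sets:
        start, end = span_sets[set_idx][0]
        targets.append((q_pos, old_to_new[start], old_to_new[end]))
    return new_tokens, targets


tokens = "the red fox saw the red fox run".split()
new_tokens, targets = mask_recurring_spans(tokens, [[(1, 2), (5, 6)]])
print(new_tokens)
print(targets)
# prints ['the', 'red', 'fox', 'saw', 'the', '[QUESTION]', 'run']
# prints [(5, 1, 2)]
```

Each target triple reads as (question position, start position, end position), which loosely mirrors the `question_positions`, `start_positions` and `end_positions` arguments of the forward method.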

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None token_type_ids: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None start_positions: typing.Optional[torch.LongTensor] = None end_positions: typing.Optional[torch.LongTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None question_positions: typing.Optional[torch.LongTensor] = None ) → transformers.models.splinter.modeling_splinter.SplinterForPreTrainingOutput or tuple(torch.FloatTensor)

Parameters

Returns

transformers.models.splinter.modeling_splinter.SplinterForPreTrainingOutput or tuple(torch.FloatTensor)

A transformers.models.splinter.modeling_splinter.SplinterForPreTrainingOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SplinterConfig) and inputs.

The SplinterForPreTraining forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
