BART (original)

FlaxBartModel

class transformers.FlaxBartModel

( config: BartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: jnp.dtype = jnp.float32 _do_init: bool = True **kwargs )

The bare Bart Model transformer outputting raw hidden-states without any specific head on top. This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
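
The dtype argument selects the numerical precision used for the computation. A minimal, illustrative sketch (the bfloat16 choice is an assumption for demonstration, not a recommendation from this page; dtype sets the computation dtype, while the stored parameters can be cast separately with to_bf16):

import jax.numpy as jnp
from transformers import FlaxBartModel

# run the computation in bfloat16 (float32 is the default)
model = FlaxBartModel.from_pretrained("facebook/bart-base", dtype=jnp.bfloat16)
# optionally also cast the parameters themselves to bfloat16
model.params = model.to_bf16(model.params)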

This model is also a Flax Linen (flax.linen.Module) subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

- Just-In-Time (JIT) compilation
- Automatic Differentiation
- Vectorization
- Parallelization
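
For example, the model call can be compiled with jax.jit; a minimal sketch (the wrapper function and variable names are illustrative assumptions):

import jax
from transformers import AutoTokenizer, FlaxBartModel

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartModel.from_pretrained("facebook/bart-base")
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

@jax.jit
def encode_forward(input_ids, attention_mask):
    # traced and compiled on the first call; later calls with the same shapes reuse the compiled code
    return model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

hidden_states = encode_forward(inputs["input_ids"], inputs["attention_mask"])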

__call__

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None decoder_input_ids: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxSeq2SeqModelOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxSeq2SeqModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.

The FlaxBartPreTrainedModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this one, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, FlaxBartModel

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartModel.from_pretrained("facebook/bart-base")

inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state

encode

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.

Example:

from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decode

( decoder_input_ids encoder_outputs encoder_attention_mask: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None past_key_values: dict = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.

Example:

import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decoder_start_token_id = model.config.decoder_start_token_id
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

outputs = model.decode(decoder_input_ids, encoder_outputs)
last_decoder_hidden_states = outputs.last_hidden_state

FlaxBartForConditionalGeneration

class transformers.FlaxBartForConditionalGeneration

( config: BartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: jnp.dtype = jnp.float32 _do_init: bool = True **kwargs )

The BART Model with a language modeling head. Can be used for summarization. This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a Flax Linen (flax.linen.Module) subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

- Just-In-Time (JIT) compilation
- Automatic Differentiation
- Vectorization
- Parallelization

__call__

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None decoder_input_ids: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxSeq2SeqLMOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxSeq2SeqLMOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.

The FlaxBartPreTrainedModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this one, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Summarization example:

from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="np")

summary_ids = model.generate(inputs["input_ids"]).sequences
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
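
generate() also accepts the usual generation keyword arguments; as a small illustrative variation (num_beams and max_length are standard generation parameters, the values here are arbitrary):

summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=20).sequences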

Mask filling example:

import jax
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

TXT = "My friends are <mask> but they eat too many carbs."
input_ids = tokenizer([TXT], return_tensors="jax")["input_ids"]

logits = model(input_ids).logits
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
probs = jax.nn.softmax(logits[0, masked_index], axis=0)
values, predictions = jax.lax.top_k(probs, k=1)

tokenizer.decode(predictions).split()

encode

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.

Example:

from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decode

( decoder_input_ids encoder_outputs encoder_attention_mask: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None past_key_values: dict = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.

Example:

import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decoder_start_token_id = model.config.decoder_start_token_id
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

outputs = model.decode(decoder_input_ids, encoder_outputs)
logits = outputs.logits
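
Because decode() accepts past_key_values, it can also be called step by step with the key/value cache for fast autoregressive decoding. The following is a rough, greedy sketch (assumptions: init_cache is the cache initializer exposed by the Flax BART pretrained model, decoder_position_ids must be supplied whenever past_key_values is passed, and model.generate performs the same bookkeeping more robustly):

max_length = 16
batch_size = inputs.input_ids.shape[0]
past_key_values = model.init_cache(batch_size, max_length, encoder_outputs)
# full-length mask, as in the generation utilities; positions are tracked manually
decoder_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")

token = jnp.full((batch_size, 1), model.config.decoder_start_token_id, dtype="i4")
generated = token
for step in range(max_length - 1):
    position_ids = jnp.full((batch_size, 1), step, dtype="i4")
    outputs = model.decode(
        token,
        encoder_outputs,
        decoder_attention_mask=decoder_attention_mask,
        decoder_position_ids=position_ids,
        past_key_values=past_key_values,
    )
    past_key_values = outputs.past_key_values  # updated cache, so past tokens are not re-run
    token = outputs.logits[:, -1].argmax(-1)[:, None].astype("i4")  # greedy next token
    generated = jnp.concatenate([generated, token], axis=-1)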

FlaxBartForSequenceClassification

class transformers.FlaxBartForSequenceClassification

( config: BartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: jnp.dtype = jnp.float32 _do_init: bool = True **kwargs )

Bart model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE tasks.

This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a Flax Linen (flax.linen.Module) subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

- Just-In-Time (JIT) compilation
- Automatic Differentiation
- Vectorization
- Parallelization

__call__

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None decoder_input_ids: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxSeq2SeqSequenceClassifierOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxSeq2SeqSequenceClassifierOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.

The FlaxBartPreTrainedModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this one, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, FlaxBartForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartForSequenceClassification.from_pretrained("facebook/bart-base")

inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

outputs = model(**inputs)
logits = outputs.logits
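
The predicted class is the argmax over the logits; a minimal sketch (id2label only yields meaningful names on a checkpoint fine-tuned for classification, which facebook/bart-base is not):

predicted_class_id = int(logits.argmax(-1)[0])
print(model.config.id2label[predicted_class_id])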

encode

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.

Example:

from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decode

( decoder_input_ids encoder_outputs encoder_attention_mask: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None past_key_values: dict = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.

Example:

import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decoder_start_token_id = model.config.decoder_start_token_id
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

outputs = model.decode(decoder_input_ids, encoder_outputs)
last_decoder_hidden_states = outputs.last_hidden_state

FlaxBartForQuestionAnswering

class transformers.FlaxBartForQuestionAnswering

( config: BartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: jnp.dtype = jnp.float32 _do_init: bool = True **kwargs )

BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute span start logits and span end logits).

This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a Flax Linen (flax.linen.Module) subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

- Just-In-Time (JIT) compilation
- Automatic Differentiation
- Vectorization
- Parallelization

__call__

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None decoder_input_ids: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxSeq2SeqQuestionAnsweringModelOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxSeq2SeqQuestionAnsweringModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.

The FlaxBartPreTrainedModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this one, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, FlaxBartForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartForQuestionAnswering.from_pretrained("facebook/bart-base")

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, text, return_tensors="jax")

outputs = model(**inputs)
start_scores = outputs.start_logits
end_scores = outputs.end_logits
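
The most likely answer span can be read off the start/end logits; a minimal sketch (facebook/bart-base is not fine-tuned for QA, so the extracted span is illustrative only):

answer_start = int(start_scores.argmax(-1)[0])
answer_end = int(end_scores.argmax(-1)[0])
answer = tokenizer.decode(inputs["input_ids"][0, answer_start : answer_end + 1])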

encode

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.

Example:

from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decode

( decoder_input_ids encoder_outputs encoder_attention_mask: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None past_key_values: dict = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.

Example:

import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decoder_start_token_id = model.config.decoder_start_token_id
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

outputs = model.decode(decoder_input_ids, encoder_outputs)
last_decoder_hidden_states = outputs.last_hidden_state

FlaxBartForCausalLM

class transformers.FlaxBartForCausalLM

( config: BartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: jnp.dtype = jnp.float32 _do_init: bool = True **kwargs )

Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings), e.g. for autoregressive tasks.

This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a Flax Linen (flax.linen.Module) subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

- Just-In-Time (JIT) compilation
- Automatic Differentiation
- Vectorization
- Parallelization

__call__

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None encoder_hidden_states: typing.Optional[jax.Array] = None encoder_attention_mask: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None past_key_values: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.

The FlaxBartDecoderPreTrainedModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this one, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, FlaxBartForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartForCausalLM.from_pretrained("facebook/bart-base")

inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(**inputs)

next_token_logits = outputs.logits[:, -1]
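
A greedy next-token prediction follows directly from these logits; a minimal sketch:

next_token_id = int(next_token_logits.argmax(-1)[0])
print(tokenizer.decode([next_token_id]))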