BART
FlaxBartModel
class transformers.FlaxBartModel
( config: BartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: dtype = jax.numpy.float32 _do_init: bool = True **kwargs )
Parameters
- config (BartConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
- dtype (jax.numpy.dtype, optional, defaults to jax.numpy.float32) — The data type of the computation. Can be one of jax.numpy.float32, jax.numpy.float16 (on GPUs) and jax.numpy.bfloat16 (on TPUs).
  This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified, all the computation will be performed with the given dtype.
  Note that this only specifies the dtype of the computation and does not influence the dtype of the model parameters.
  If you wish to change the dtype of the model parameters, see to_fp16() and to_bf16().
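As a minimal sketch of the dtype mechanics described above (assuming a GPU host; on TPU you would use jax.numpy.bfloat16 instead), the snippet below loads the checkpoint so that computation runs in half precision and then also casts the parameters with to_fp16():

```python
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBartModel

# Run the computation in float16; the parameters are still loaded in float32.
model = FlaxBartModel.from_pretrained("facebook/bart-base", dtype=jnp.float16)

# Optionally also cast the parameters themselves to half precision.
model.params = model.to_fp16(model.params)

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
outputs = model(**inputs)
```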
The bare Bart Model transformer outputting raw hidden-states without any specific head on top. This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.).
This model is also a Flax Linen flax.nn.Module subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as Just-In-Time (JIT) compilation, automatic differentiation, vectorization and parallelization.
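As an illustration of the JIT support, here is a hedged sketch that wraps the forward pass in jax.jit; the wrapper function and the fixed-length padding are our own additions, not part of the documented API, and recompilation happens whenever the input shape changes:

```python
import jax
from transformers import AutoTokenizer, FlaxBartModel

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartModel.from_pretrained("facebook/bart-base")

# JIT-compile the forward pass; padding to a fixed length keeps it to one compilation.
@jax.jit
def forward(input_ids, attention_mask):
    return model(input_ids, attention_mask=attention_mask).last_hidden_state

inputs = tokenizer("Hello, my dog is cute", padding="max_length", max_length=32, return_tensors="jax")
hidden = forward(inputs["input_ids"], inputs["attention_mask"])
```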
__call__
( input_ids: Array attention_mask: typing.Optional[jax.Array] = None decoder_input_ids: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxSeq2SeqModelOutput or tuple(jnp.ndarray)
Parameters
- input_ids (jnp.ndarray of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details. What are input IDs?
- attention_mask (jnp.ndarray of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked. What are attention masks?
- decoder_input_ids (jnp.ndarray of shape (batch_size, target_sequence_length), optional) — Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details. What are decoder input IDs? For translation and summarization training, decoder_input_ids should be provided. If no decoder_input_ids is provided, the model will create this tensor by shifting input_ids to the right for denoising pre-training, following the paper.
- decoder_attention_mask (jnp.ndarray of shape (batch_size, target_sequence_length), optional) — Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. A causal mask will also be used by default. If you want to change padding behavior, you should modify it to your needs. See diagram 1 in the paper for more information on the default strategy.
- position_ids (numpy.ndarray of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- decoder_position_ids (numpy.ndarray of shape (batch_size, sequence_length), optional) — Indices of positions of each decoder input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- output_attentions (bool, optional) — Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
A transformers.modeling_flax_outputs.FlaxSeq2SeqModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.
- last_hidden_state (jnp.ndarray of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the decoder of the model. If past_key_values is used, only the last hidden-state of the sequences, of shape (batch_size, 1, hidden_size), is output.
- past_key_values (tuple(tuple(jnp.ndarray)), optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of tuple(jnp.ndarray) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head) and 2 additional tensors of shape (batch_size, num_heads, encoder_sequence_length, embed_size_per_head). Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see the past_key_values input) to speed up sequential decoding.
- decoder_hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
- decoder_attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads.
- cross_attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.
- encoder_last_hidden_state (jnp.ndarray of shape (batch_size, sequence_length, hidden_size), optional) — Sequence of hidden-states at the output of the last layer of the encoder of the model.
- encoder_hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
- encoder_attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads.
The FlaxBartPreTrainedModel forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
from transformers import AutoTokenizer, FlaxBartModel

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartModel.from_pretrained("facebook/bart-base")

inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state
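For completeness, a small sketch of the return_dict=False path described in the parameters above; the tuple ordering follows the output class fields, so treat this as illustrative rather than normative:

```python
from transformers import AutoTokenizer, FlaxBartModel

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartModel.from_pretrained("facebook/bart-base")
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

# With return_dict=False the model returns a plain tuple instead of a
# FlaxSeq2SeqModelOutput; the first element is the decoder's last hidden state.
outputs = model(**inputs, return_dict=False)
last_hidden_states = outputs[0]
```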
encode
( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutput or tuple(jnp.ndarray)
Parameters
- input_ids (jnp.ndarray of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details. What are input IDs?
- attention_mask (jnp.ndarray of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked. What are attention masks?
- position_ids (numpy.ndarray of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- output_attentions (bool, optional) — Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
A transformers.modeling_flax_outputs.FlaxBaseModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.
- last_hidden_state (jnp.ndarray of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.
- hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Example:
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)
decode
( decoder_input_ids encoder_outputs encoder_attention_mask: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None past_key_values: dict = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or tuple(jnp.ndarray)
Parameters
- decoder_input_ids (jnp.ndarray of shape (batch_size, target_sequence_length)) — Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details. What are decoder input IDs? For translation and summarization training, decoder_input_ids should be provided. If no decoder_input_ids is provided, the model will create this tensor by shifting input_ids to the right for denoising pre-training, following the paper.
- encoder_outputs (tuple(tuple(jnp.ndarray))) — Tuple consisting of (last_hidden_state, optional: hidden_states, optional: attentions). last_hidden_state of shape (batch_size, sequence_length, hidden_size) is a sequence of hidden-states at the output of the last layer of the encoder, used in the cross-attention of the decoder.
- encoder_attention_mask (jnp.ndarray of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked. What are attention masks?
- decoder_attention_mask (jnp.ndarray of shape (batch_size, target_sequence_length), optional) — Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. A causal mask will also be used by default. If you want to change padding behavior, you should modify it to your needs. See diagram 1 in the paper for more information on the default strategy.
- decoder_position_ids (numpy.ndarray of shape (batch_size, sequence_length), optional) — Indices of positions of each decoder input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- past_key_values (Dict[str, np.ndarray], optional, returned by init_cache or when passing previous past_key_values) — Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape [batch_size, max_length].
- output_attentions (bool, optional) — Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
A transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.
- last_hidden_state (jnp.ndarray of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model. If past_key_values is used, only the last hidden-state of the sequences, of shape (batch_size, 1, hidden_size), is output.
- past_key_values (tuple(tuple(jnp.ndarray)), optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of tuple(jnp.ndarray) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head) and, optionally if config.is_encoder_decoder=True, 2 additional tensors of shape (batch_size, num_heads, encoder_sequence_length, embed_size_per_head). Contains pre-computed hidden-states (key and values in the self-attention blocks and, optionally if config.is_encoder_decoder=True, in the cross-attention blocks) that can be used (see the past_key_values input) to speed up sequential decoding.
- hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- cross_attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True and config.add_cross_attention=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.
Example:
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decoder_start_token_id = model.config.decoder_start_token_id
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

outputs = model.decode(decoder_input_ids, encoder_outputs)
last_decoder_hidden_states = outputs.last_hidden_state
FlaxBartForConditionalGeneration
class transformers.FlaxBartForConditionalGeneration
( config: BartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: dtype = jax.numpy.float32 _do_init: bool = True **kwargs )
Parameters
- config (BartConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
- dtype (jax.numpy.dtype, optional, defaults to jax.numpy.float32) — The data type of the computation. Can be one of jax.numpy.float32, jax.numpy.float16 (on GPUs) and jax.numpy.bfloat16 (on TPUs).
  This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified, all the computation will be performed with the given dtype.
  Note that this only specifies the dtype of the computation and does not influence the dtype of the model parameters.
  If you wish to change the dtype of the model parameters, see to_fp16() and to_bf16().
The BART Model with a language modeling head. Can be used for summarization. This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.).
This model is also a Flax Linen flax.nn.Module subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as Just-In-Time (JIT) compilation, automatic differentiation, vectorization and parallelization.
__call__
( input_ids: Array attention_mask: typing.Optional[jax.Array] = None decoder_input_ids: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxSeq2SeqLMOutput or tuple(jnp.ndarray)
Parameters
- input_ids (jnp.ndarray of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details. What are input IDs?
- attention_mask (jnp.ndarray of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked. What are attention masks?
- decoder_input_ids (jnp.ndarray of shape (batch_size, target_sequence_length), optional) — Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details. What are decoder input IDs? For translation and summarization training, decoder_input_ids should be provided. If no decoder_input_ids is provided, the model will create this tensor by shifting input_ids to the right for denoising pre-training, following the paper.
- decoder_attention_mask (jnp.ndarray of shape (batch_size, target_sequence_length), optional) — Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. A causal mask will also be used by default. If you want to change padding behavior, you should modify it to your needs. See diagram 1 in the paper for more information on the default strategy.
- position_ids (numpy.ndarray of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- decoder_position_ids (numpy.ndarray of shape (batch_size, sequence_length), optional) — Indices of positions of each decoder input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- output_attentions (bool, optional) — Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
A transformers.modeling_flax_outputs.FlaxSeq2SeqLMOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.
- logits (jnp.ndarray of shape (batch_size, sequence_length, config.vocab_size)) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (tuple(tuple(jnp.ndarray)), optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of tuple(jnp.ndarray) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head) and 2 additional tensors of shape (batch_size, num_heads, encoder_sequence_length, embed_size_per_head). Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see the past_key_values input) to speed up sequential decoding.
- decoder_hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
- decoder_attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads.
- cross_attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.
- encoder_last_hidden_state (jnp.ndarray of shape (batch_size, sequence_length, hidden_size), optional) — Sequence of hidden-states at the output of the last layer of the encoder of the model.
- encoder_hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
- encoder_attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads.
The FlaxBartPreTrainedModel forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Summarization example:
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="np")

summary_ids = model.generate(inputs["input_ids"]).sequences
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
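The generate() call above uses the model's default generation settings. The sketch below is a hedged variant assuming the usual generation keyword arguments (num_beams, max_length, early_stopping) are accepted, as in other Transformers generate() implementations; adjust to your transformers version if needed:

```python
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
inputs = tokenizer(["My friends are cool but they eat too many carbs."], max_length=1024, return_tensors="np")

# Beam search with an explicit length budget instead of the defaults.
summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=60, early_stopping=True).sequences
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```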
Mask filling example:
import jax
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

TXT = "My friends are <mask> but they eat too many carbs."
input_ids = tokenizer([TXT], return_tensors="jax")["input_ids"]

logits = model(input_ids).logits
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
probs = jax.nn.softmax(logits[0, masked_index], axis=0)
values, predictions = jax.lax.top_k(probs, k=1)

tokenizer.decode(predictions).split()
encode
( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutput or tuple(jnp.ndarray)
Parameters
- input_ids (jnp.ndarray of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details. What are input IDs?
- attention_mask (jnp.ndarray of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked. What are attention masks?
- position_ids (numpy.ndarray of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- output_attentions (bool, optional) — Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
A transformers.modeling_flax_outputs.FlaxBaseModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.
- last_hidden_state (jnp.ndarray of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.
- hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Example:
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)
decode
( decoder_input_ids encoder_outputs encoder_attention_mask: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None past_key_values: dict = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions or tuple(jnp.ndarray)
Parameters
- decoder_input_ids (jnp.ndarray of shape (batch_size, target_sequence_length)) — Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details. What are decoder input IDs? For translation and summarization training, decoder_input_ids should be provided. If no decoder_input_ids is provided, the model will create this tensor by shifting input_ids to the right for denoising pre-training, following the paper.
- encoder_outputs (tuple(tuple(jnp.ndarray))) — Tuple consisting of (last_hidden_state, optional: hidden_states, optional: attentions). last_hidden_state of shape (batch_size, sequence_length, hidden_size) is a sequence of hidden-states at the output of the last layer of the encoder, used in the cross-attention of the decoder.
- encoder_attention_mask (jnp.ndarray of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked. What are attention masks?
- decoder_attention_mask (jnp.ndarray of shape (batch_size, target_sequence_length), optional) — Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. A causal mask will also be used by default. If you want to change padding behavior, you should modify it to your needs. See diagram 1 in the paper for more information on the default strategy.
- decoder_position_ids (numpy.ndarray of shape (batch_size, sequence_length), optional) — Indices of positions of each decoder input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- past_key_values (Dict[str, np.ndarray], optional, returned by init_cache or when passing previous past_key_values) — Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape [batch_size, max_length].
- output_attentions (bool, optional) — Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
A transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.
- logits (jnp.ndarray of shape (batch_size, sequence_length, config.vocab_size)) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- cross_attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Cross-attention weights after the attention softmax, used to compute the weighted average in the cross-attention heads.
- past_key_values (tuple(tuple(jnp.ndarray)), optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of jnp.ndarray tuples of length config.n_layers, with each tuple containing the cached key and value states of the self-attention layers and, if the model is used in an encoder-decoder setting, of the cross-attention layers. Only relevant if config.is_decoder=True. Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see the past_key_values input) to speed up sequential decoding.
Example:
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decoder_start_token_id = model.config.decoder_start_token_id
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

outputs = model.decode(decoder_input_ids, encoder_outputs)
logits = outputs.logits
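Building on the decode() call above, here is a hedged sketch of a simple greedy loop that re-runs the decoder on the growing prefix and picks the arg-max token at each step. It deliberately omits the past_key_values cache for clarity, so it is slow; model.generate() is the efficient, supported path:

```python
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="jax")
encoder_outputs = model.encode(**inputs)

# Start from the configured decoder start token and greedily extend the prefix.
decoder_input_ids = jnp.full((1, 1), model.config.decoder_start_token_id, dtype="i4")
for _ in range(20):
    logits = model.decode(decoder_input_ids, encoder_outputs).logits
    next_token = jnp.argmax(logits[:, -1, :], axis=-1)[:, None]
    decoder_input_ids = jnp.concatenate([decoder_input_ids, next_token], axis=-1)
    if next_token[0, 0] == model.config.eos_token_id:
        break

print(tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True))
```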
FlaxBartForSequenceClassification
class transformers.FlaxBartForSequenceClassification
( config: BartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: dtype = jax.numpy.float32 _do_init: bool = True **kwargs )
Parameters
- config (BartConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
- dtype (jax.numpy.dtype, optional, defaults to jax.numpy.float32) — The data type of the computation. Can be one of jax.numpy.float32, jax.numpy.float16 (on GPUs) and jax.numpy.bfloat16 (on TPUs).
  This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified, all the computation will be performed with the given dtype.
  Note that this only specifies the dtype of the computation and does not influence the dtype of the model parameters.
  If you wish to change the dtype of the model parameters, see to_fp16() and to_bf16().
Bart model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE tasks.
This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.).
This model is also a Flax Linen flax.nn.Module subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as Just-In-Time (JIT) compilation, automatic differentiation, vectorization and parallelization.
__call__
( input_ids: Array attention_mask: typing.Optional[jax.Array] = None decoder_input_ids: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxSeq2SeqSequenceClassifierOutput or tuple(jnp.ndarray)
Parameters
- input_ids (jnp.ndarray of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details. What are input IDs?
- attention_mask (jnp.ndarray of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked. What are attention masks?
- decoder_input_ids (jnp.ndarray of shape (batch_size, target_sequence_length), optional) — Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details. What are decoder input IDs? For translation and summarization training, decoder_input_ids should be provided. If no decoder_input_ids is provided, the model will create this tensor by shifting input_ids to the right for denoising pre-training, following the paper.
- decoder_attention_mask (jnp.ndarray of shape (batch_size, target_sequence_length), optional) — Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. A causal mask will also be used by default. If you want to change padding behavior, you should modify it to your needs. See diagram 1 in the paper for more information on the default strategy.
- position_ids (numpy.ndarray of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- decoder_position_ids (numpy.ndarray of shape (batch_size, sequence_length), optional) — Indices of positions of each decoder input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- output_attentions (bool, optional) — Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
A transformers.modeling_flax_outputs.FlaxSeq2SeqSequenceClassifierOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.
- logits (jnp.ndarray of shape (batch_size, config.num_labels)) — Classification (or regression if config.num_labels==1) scores (before SoftMax).
- past_key_values (tuple(tuple(jnp.ndarray)), optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of tuple(jnp.ndarray) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head) and 2 additional tensors of shape (batch_size, num_heads, encoder_sequence_length, embed_size_per_head). Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see the past_key_values input) to speed up sequential decoding.
- decoder_hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
- decoder_attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads.
- cross_attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.
- encoder_last_hidden_state (jnp.ndarray of shape (batch_size, sequence_length, hidden_size), optional) — Sequence of hidden-states at the output of the last layer of the encoder of the model.
- encoder_hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
- encoder_attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads.
The FlaxBartPreTrainedModel forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
from transformers import AutoTokenizer, FlaxBartForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") model = FlaxBartForSequenceClassification.from_pretrained("facebook/bart-base")
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
outputs = model(**inputs) logits = outputs.logits
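To turn the classification logits above into a label prediction, a small sketch follows; note that the id2label mapping only carries meaningful names once the classification head has been fine-tuned on a labeled task:

```python
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBartForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartForSequenceClassification.from_pretrained("facebook/bart-base")

inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
logits = model(**inputs).logits

# Convert the raw scores to probabilities and pick the highest-scoring class.
probs = jax.nn.softmax(logits, axis=-1)
predicted_class_id = int(jnp.argmax(probs, axis=-1)[0])
print(model.config.id2label[predicted_class_id])
```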
encode
( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutput or tuple(jnp.ndarray)
Parameters
- input_ids (jnp.ndarray of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details. What are input IDs?
- attention_mask (jnp.ndarray of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked. What are attention masks?
- position_ids (numpy.ndarray of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- output_attentions (bool, optional) — Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
A transformers.modeling_flax_outputs.FlaxBaseModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.
- last_hidden_state (jnp.ndarray of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.
- hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Example:
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)
decode
( decoder_input_ids encoder_outputs encoder_attention_mask: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None past_key_values: dict = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or tuple(jnp.ndarray)
Parameters
- decoder_input_ids (jnp.ndarray of shape (batch_size, target_sequence_length)) — Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details. What are decoder input IDs? For translation and summarization training, decoder_input_ids should be provided. If no decoder_input_ids is provided, the model will create this tensor by shifting input_ids to the right for denoising pre-training, following the paper.
- encoder_outputs (tuple(tuple(jnp.ndarray))) — Tuple consisting of (last_hidden_state, optional: hidden_states, optional: attentions). last_hidden_state of shape (batch_size, sequence_length, hidden_size) is a sequence of hidden-states at the output of the last layer of the encoder, used in the cross-attention of the decoder.
- encoder_attention_mask (jnp.ndarray of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked. What are attention masks?
- decoder_attention_mask (jnp.ndarray of shape (batch_size, target_sequence_length), optional) — Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. A causal mask will also be used by default. If you want to change padding behavior, you should modify it to your needs. See diagram 1 in the paper for more information on the default strategy.
- decoder_position_ids (numpy.ndarray of shape (batch_size, sequence_length), optional) — Indices of positions of each decoder input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- past_key_values (Dict[str, np.ndarray], optional, returned by init_cache or when passing previous past_key_values) — Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape [batch_size, max_length].
- output_attentions (bool, optional) — Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
A transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (BartConfig) and inputs.
- last_hidden_state (jnp.ndarray of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model. If past_key_values is used, only the last hidden-state of the sequences, of shape (batch_size, 1, hidden_size), is output.
- past_key_values (tuple(tuple(jnp.ndarray)), optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of tuple(jnp.ndarray) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head) and, optionally if config.is_encoder_decoder=True, 2 additional tensors of shape (batch_size, num_heads, encoder_sequence_length, embed_size_per_head). Contains pre-computed hidden-states (key and values in the self-attention blocks and, optionally if config.is_encoder_decoder=True, in the cross-attention blocks) that can be used (see the past_key_values input) to speed up sequential decoding.
- hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- cross_attentions (tuple(jnp.ndarray), optional, returned when output_attentions=True and config.add_cross_attention=True is passed or when config.output_attentions=True) — Tuple of jnp.ndarray (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.
Example:
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decoder_start_token_id = model.config.decoder_start_token_id
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

outputs = model.decode(decoder_input_ids, encoder_outputs)
last_decoder_hidden_states = outputs.last_hidden_state
FlaxBartForQuestionAnswering
class transformers.FlaxBartForQuestionAnswering
( config: BartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: dtype = jax.numpy.float32 _do_init: bool = True **kwargs )
Parameters
- config (BartConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
- dtype (jax.numpy.dtype, optional, defaults to jax.numpy.float32) — The data type of the computation. Can be one of jax.numpy.float32, jax.numpy.float16 (on GPUs) and jax.numpy.bfloat16 (on TPUs).
  This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified, all the computation will be performed with the given dtype.
  Note that this only specifies the dtype of the computation and does not influence the dtype of the model parameters.
  If you wish to change the dtype of the model parameters, see to_fp16() and to_bf16().
BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute span start logits and span end logits).
This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.).
This model is also a Flax Linen flax.nn.Module subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as Just-In-Time (JIT) compilation, automatic differentiation, vectorization and parallelization.
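Since no usage example for this class survives in this section, here is a hedged sketch of extracting an answer span from the start and end logits documented below. The question and context strings are our own, and the base facebook/bart-base checkpoint has a randomly initialized QA head, so the decoded span is only meaningful after fine-tuning:

```python
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBartForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartForQuestionAnswering.from_pretrained("facebook/bart-base")

question = "Who wrote BART?"
context = "BART was proposed by researchers at Facebook AI."
inputs = tokenizer(question, context, return_tensors="jax")

outputs = model(**inputs)

# Pick the most likely start and end positions and decode the span between them.
start_index = int(jnp.argmax(outputs.start_logits, axis=-1)[0])
end_index = int(jnp.argmax(outputs.end_logits, axis=-1)[0])
answer_tokens = inputs["input_ids"][0, start_index : end_index + 1]
print(tokenizer.decode(answer_tokens, skip_special_tokens=True))
```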
__call__
( input_ids: Array attention_mask: typing.Optional[jax.Array] = None decoder_input_ids: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxSeq2SeqQuestionAnsweringModelOutput or tuple(jnp.ndarray)
Parameters
- input_ids (
jnp.ndarray
of shape(batch_size, sequence_length)
) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() andPreTrainedTokenizer.call() for details.
What are input IDs? - attention_mask (
jnp.ndarray
of shape(batch_size, sequence_length)
, optional) — Mask to avoid performing attention on padding token indices. Mask values selected in[0, 1]
:- 1 for tokens that are not masked,
- 0 for tokens that are masked.
What are attention masks?
- decoder_input_ids (
jnp.ndarray
of shape(batch_size, target_sequence_length)
, optional) — Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() andPreTrainedTokenizer.call() for details.
What are decoder input IDs?
For translation and summarization training,decoder_input_ids
should be provided. If nodecoder_input_ids
is provided, the model will create this tensor by shifting theinput_ids
to the right for denoising pre-training following the paper. - decoder_attention_mask (
jnp.ndarray
of shape(batch_size, target_sequence_length)
, optional) — Default behavior: generate a tensor that ignores pad tokens indecoder_input_ids
. Causal mask will also be used by default.
If you want to change padding behavior, you should modify to your needs. See diagram 1 in the paper for more information on the default strategy. - position_ids (
numpy.ndarray
of shape(batch_size, sequence_length)
, optional) — Indices of positions of each input sequence tokens in the position embeddings. Selected in the range[0, config.max_position_embeddings - 1]
. - decoder_position_ids (
numpy.ndarray
of shape(batch_size, sequence_length)
, optional) — Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range[0, config.max_position_embeddings - 1]
. - output_attentions (
bool
, optional) — Whether or not to return the attentions tensors of all attention layers. Seeattentions
under returned tensors for more detail. - output_hidden_states (
bool
, optional) — Whether or not to return the hidden states of all layers. Seehidden_states
under returned tensors for more detail. - return_dict (
bool
, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
A transformers.modeling_flax_outputs.FlaxSeq2SeqQuestionAnsweringModelOutput or a tuple of jnp.ndarray
(if return_dict=False
is passed or when config.return_dict=False
) comprising various elements depending on the configuration (BartConfig) and inputs.
- start_logits (
jnp.ndarray
of shape(batch_size, sequence_length)
) — Span-start scores (before SoftMax). - end_logits (
jnp.ndarray
of shape(batch_size, sequence_length)
) — Span-end scores (before SoftMax). - past_key_values (
tuple(tuple(jnp.ndarray))
, optional, returned whenuse_cache=True
is passed or whenconfig.use_cache=True
) — Tuple oftuple(jnp.ndarray)
of lengthconfig.n_layers
, with each tuple having 2 tensors of shape(batch_size, num_heads, sequence_length, embed_size_per_head)
and 2 additional tensors of shape(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)
.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (seepast_key_values
input) to speed up sequential decoding. - decoder_hidden_states (
tuple(jnp.ndarray)
, optional, returned whenoutput_hidden_states=True
is passed or whenconfig.output_hidden_states=True
) — Tuple ofjnp.ndarray
(one for the output of the embeddings + one for the output of each layer) of shape(batch_size, sequence_length, hidden_size)
.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (
tuple(jnp.ndarray)
, optional, returned whenoutput_attentions=True
is passed or whenconfig.output_attentions=True
) — Tuple ofjnp.ndarray
(one for each layer) of shape(batch_size, num_heads, sequence_length, sequence_length)
.
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. - cross_attentions (
tuple(jnp.ndarray)
, optional, returned whenoutput_attentions=True
is passed or whenconfig.output_attentions=True
) — Tuple ofjnp.ndarray
(one for each layer) of shape(batch_size, num_heads, sequence_length, sequence_length)
.
Attention weights of the decoder’s cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. - encoder_last_hidden_state (
jnp.ndarray
of shape(batch_size, sequence_length, hidden_size)
, optional) — Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (
tuple(jnp.ndarray)
, optional, returned whenoutput_hidden_states=True
is passed or whenconfig.output_hidden_states=True
) — Tuple ofjnp.ndarray
(one for the output of the embeddings + one for the output of each layer) of shape(batch_size, sequence_length, hidden_size)
.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (
tuple(jnp.ndarray)
, optional, returned whenoutput_attentions=True
is passed or whenconfig.output_attentions=True
) — Tuple ofjnp.ndarray
(one for each layer) of shape(batch_size, num_heads, sequence_length, sequence_length)
.
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads.
The FlaxBartPreTrainedModel forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this one, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
from transformers import AutoTokenizer, FlaxBartForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartForQuestionAnswering.from_pretrained("facebook/bart-base")

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, text, return_tensors="jax")

outputs = model(**inputs)
start_scores = outputs.start_logits
end_scores = outputs.end_logits
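Continuing the example above, the span logits can be turned into an answer string by taking the argmax of the start and end scores and decoding the corresponding tokens. This post-processing step is an illustrative addition, not part of the original reference, and assumes the predicted start position comes before the predicted end position.

# pick the most likely start and end token positions for batch element 0
start_index = int(start_scores[0].argmax())
end_index = int(end_scores[0].argmax())

# decode the predicted span back to text (assumes start_index <= end_index)
answer_tokens = inputs["input_ids"][0, start_index : end_index + 1]
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)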
encode
( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutput or tuple(jnp.ndarray)
Parameters
- input_ids (
jnp.ndarray
of shape(batch_size, sequence_length)
) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() andPreTrainedTokenizer.call() for details.
What are input IDs? - attention_mask (
jnp.ndarray
of shape(batch_size, sequence_length)
, optional) — Mask to avoid performing attention on padding token indices. Mask values selected in[0, 1]
:- 1 for tokens that are not masked,
- 0 for tokens that are masked.
What are attention masks?
- position_ids (
numpy.ndarray
of shape(batch_size, sequence_length)
, optional) — Indices of positions of each input sequence tokens in the position embeddings. Selected in the range[0, config.max_position_embeddings - 1]
. - output_attentions (
bool
, optional) — Whether or not to return the attentions tensors of all attention layers. Seeattentions
under returned tensors for more detail. - output_hidden_states (
bool
, optional) — Whether or not to return the hidden states of all layers. Seehidden_states
under returned tensors for more detail. - return_dict (
bool
, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
A transformers.modeling_flax_outputs.FlaxBaseModelOutput or a tuple of jnp.ndarray
(if return_dict=False
is passed or when config.return_dict=False
) comprising various elements depending on the configuration (<class 'transformers.models.bart.configuration_bart.BartConfig'>
) and inputs.
- last_hidden_state (
jnp.ndarray
of shape(batch_size, sequence_length, hidden_size)
) — Sequence of hidden-states at the output of the last layer of the model. - hidden_states (
tuple(jnp.ndarray)
, optional, returned whenoutput_hidden_states=True
is passed or whenconfig.output_hidden_states=True
) — Tuple ofjnp.ndarray
(one for the output of the embeddings + one for the output of each layer) of shape(batch_size, sequence_length, hidden_size)
.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (
tuple(jnp.ndarray)
, optional, returned whenoutput_attentions=True
is passed or whenconfig.output_attentions=True
) — Tuple ofjnp.ndarray
(one for each layer) of shape(batch_size, num_heads, sequence_length, sequence_length)
.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Example:
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)
decode
( decoder_input_ids encoder_outputs encoder_attention_mask: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None past_key_values: dict = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or tuple(jnp.ndarray)
Parameters
- decoder_input_ids (
jnp.ndarray
of shape(batch_size, target_sequence_length)
) — Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() andPreTrainedTokenizer.call() for details.
What are decoder input IDs?
For translation and summarization training,decoder_input_ids
should be provided. If nodecoder_input_ids
is provided, the model will create this tensor by shifting theinput_ids
to the right for denoising pre-training following the paper. - encoder_outputs (
tuple(tuple(jnp.ndarray)
) — Tuple consists of (last_hidden_state
, optional:hidden_states
, optional:attentions
)last_hidden_state
of shape(batch_size, sequence_length, hidden_size)
, optional) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (
jnp.ndarray
of shape(batch_size, sequence_length)
, optional) — Mask to avoid performing attention on padding token indices. Mask values selected in[0, 1]
:- 1 for tokens that are not masked,
- 0 for tokens that are masked.
What are attention masks?
- decoder_attention_mask (
jnp.ndarray
of shape(batch_size, target_sequence_length)
, optional) — Default behavior: generate a tensor that ignores pad tokens indecoder_input_ids
. Causal mask will also be used by default.
If you want to change padding behavior, you should modify to your needs. See diagram 1 in the paper for more information on the default strategy. - decoder_position_ids (
numpy.ndarray
of shape(batch_size, sequence_length)
, optional) — Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range[0, config.max_position_embeddings - 1]
. - past_key_values (
Dict[str, np.ndarray]
, optional, returned byinit_cache
or when passing previouspast_key_values
) — Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape [batch_size, max_length]. - output_attentions (
bool
, optional) — Whether or not to return the attentions tensors of all attention layers. Seeattentions
under returned tensors for more detail. - output_hidden_states (
bool
, optional) — Whether or not to return the hidden states of all layers. Seehidden_states
under returned tensors for more detail. - return_dict (
bool
, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
A transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or a tuple of jnp.ndarray
(if return_dict=False
is passed or when config.return_dict=False
) comprising various elements depending on the configuration (<class 'transformers.models.bart.configuration_bart.BartConfig'>
) and inputs.
- last_hidden_state (
jnp.ndarray
of shape(batch_size, sequence_length, hidden_size)
) — Sequence of hidden-states at the output of the last layer of the model.
Ifpast_key_values
is used only the last hidden-state of the sequences of shape(batch_size, 1, hidden_size)
is output. - past_key_values (
tuple(tuple(jnp.ndarray))
, optional, returned whenuse_cache=True
is passed or whenconfig.use_cache=True
) — Tuple oftuple(jnp.ndarray)
of lengthconfig.n_layers
, with each tuple having 2 tensors of shape(batch_size, num_heads, sequence_length, embed_size_per_head)
and optionally ifconfig.is_encoder_decoder=True
2 additional tensors of shape(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)
.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally ifconfig.is_encoder_decoder=True
in the cross-attention blocks) that can be used (seepast_key_values
input) to speed up sequential decoding. - hidden_states (
tuple(jnp.ndarray)
, optional, returned whenoutput_hidden_states=True
is passed or whenconfig.output_hidden_states=True
) — Tuple ofjnp.ndarray
(one for the output of the embeddings + one for the output of each layer) of shape(batch_size, sequence_length, hidden_size)
.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (
tuple(jnp.ndarray)
, optional, returned whenoutput_attentions=True
is passed or whenconfig.output_attentions=True
) — Tuple ofjnp.ndarray
(one for each layer) of shape(batch_size, num_heads, sequence_length, sequence_length)
.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. - cross_attentions (
tuple(jnp.ndarray)
, optional, returned whenoutput_attentions=True
andconfig.add_cross_attention=True
is passed or whenconfig.output_attentions=True
) — Tuple ofjnp.ndarray
(one for each layer) of shape(batch_size, num_heads, sequence_length, sequence_length)
.
Attention weights of the decoder’s cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.
Example:
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decoder_start_token_id = model.config.decoder_start_token_id
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

outputs = model.decode(decoder_input_ids, encoder_outputs)
last_decoder_hidden_states = outputs.last_hidden_state
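The past_key_values argument documented above is what enables fast auto-regressive decoding. The sketch below continues the example and is an illustrative addition, not part of the original reference; it assumes the init_cache(batch_size, max_length, encoder_outputs) helper of the Flax seq2seq models and a decoder attention mask that spans the full cache length.

import jax.numpy as jnp

max_length = 32
batch_size = inputs.input_ids.shape[0]

# pre-allocate the key/value cache for up to max_length decoder steps
past_key_values = model.init_cache(batch_size, max_length, encoder_outputs)

# during cached decoding the decoder attention mask covers the whole cache,
# and explicit position ids are required for the current step
decoder_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
decoder_position_ids = jnp.zeros((batch_size, 1), dtype="i4")

outputs = model.decode(
    decoder_input_ids,
    encoder_outputs,
    decoder_attention_mask=decoder_attention_mask,
    decoder_position_ids=decoder_position_ids,
    past_key_values=past_key_values,
)
# outputs.past_key_values now holds the updated cache for the next decoding step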
FlaxBartForCausalLM
class transformers.FlaxBartForCausalLM
( config: BartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: dtype = <class 'jax.numpy.float32'> _do_init: bool = True **kwargs )
Parameters
- config (BartConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
- dtype (
jax.numpy.dtype
, optional, defaults tojax.numpy.float32
) — The data type of the computation. Can be one ofjax.numpy.float32
,jax.numpy.float16
(on GPUs) andjax.numpy.bfloat16
(on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified all the computation will be performed with the givendtype
.
Note that this only specifies the dtype of the computation and does not influence the dtype of model parameters.
If you wish to change the dtype of the model parameters, see to_fp16() andto_bf16().
Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings), e.g. for autoregressive tasks.
This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a Flax Linen flax.nn.Module subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- Just-In-Time (JIT) compilation
- Automatic Differentiation
- Vectorization
- Parallelization
__call__
( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None encoder_hidden_states: typing.Optional[jax.Array] = None encoder_attention_mask: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None past_key_values: dict = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions or tuple(jnp.ndarray)
Parameters
- decoder_input_ids (
jnp.ndarray
of shape(batch_size, target_sequence_length)
) — Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() andPreTrainedTokenizer.call() for details.
What are decoder input IDs?
For translation and summarization training,decoder_input_ids
should be provided. If nodecoder_input_ids
is provided, the model will create this tensor by shifting theinput_ids
to the right for denoising pre-training following the paper. - encoder_outputs (
tuple(tuple(jnp.ndarray)
) — Tuple consists of (last_hidden_state
, optional:hidden_states
, optional:attentions
)last_hidden_state
of shape(batch_size, sequence_length, hidden_size)
, optional) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (
jnp.ndarray
of shape(batch_size, sequence_length)
, optional) — Mask to avoid performing attention on padding token indices. Mask values selected in[0, 1]
:- 1 for tokens that are not masked,
- 0 for tokens that are masked.
What are attention masks?
- decoder_attention_mask (
jnp.ndarray
of shape(batch_size, target_sequence_length)
, optional) — Default behavior: generate a tensor that ignores pad tokens indecoder_input_ids
. Causal mask will also be used by default.
If you want to change padding behavior, you should modify to your needs. See diagram 1 in the paper for more information on the default strategy. - decoder_position_ids (
numpy.ndarray
of shape(batch_size, sequence_length)
, optional) — Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range[0, config.max_position_embeddings - 1]
. - past_key_values (
Dict[str, np.ndarray]
, optional, returned byinit_cache
or when passing previouspast_key_values
) — Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape [batch_size, max_length]. - output_attentions (
bool
, optional) — Whether or not to return the attentions tensors of all attention layers. Seeattentions
under returned tensors for more detail. - output_hidden_states (
bool
, optional) — Whether or not to return the hidden states of all layers. Seehidden_states
under returned tensors for more detail. - return_dict (
bool
, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
A transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions or a tuple of jnp.ndarray
(if return_dict=False
is passed or when config.return_dict=False
) comprising various elements depending on the configuration (BartConfig) and inputs.
- logits (
jnp.ndarray
of shape(batch_size, sequence_length, config.vocab_size)
) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (
tuple(jnp.ndarray)
, optional, returned whenoutput_hidden_states=True
is passed or whenconfig.output_hidden_states=True
) — Tuple ofjnp.ndarray
(one for the output of the embeddings + one for the output of each layer) of shape(batch_size, sequence_length, hidden_size)
.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (
tuple(jnp.ndarray)
, optional, returned whenoutput_attentions=True
is passed or whenconfig.output_attentions=True
) — Tuple ofjnp.ndarray
(one for each layer) of shape(batch_size, num_heads, sequence_length, sequence_length)
.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. - cross_attentions (
tuple(jnp.ndarray)
, optional, returned whenoutput_attentions=True
is passed or whenconfig.output_attentions=True
) — Tuple ofjnp.ndarray
(one for each layer) of shape(batch_size, num_heads, sequence_length, sequence_length)
.
Cross-attention weights after the attention softmax, used to compute the weighted average in the cross-attention heads. - past_key_values (
tuple(tuple(jnp.ndarray))
, optional, returned whenuse_cache=True
is passed or whenconfig.use_cache=True
) — Tuple ofjnp.ndarray
tuples of lengthconfig.n_layers
, with each tuple containing the cached key, value states of the self-attention and the cross-attention layers if model is used in encoder-decoder setting. Only relevant ifconfig.is_decoder = True
.
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (seepast_key_values
input) to speed up sequential decoding.
The FlaxBartDecoderPreTrainedModel forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this one, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
from transformers import AutoTokenizer, FlaxBartForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartForCausalLM.from_pretrained("facebook/bart-base")

inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(**inputs)

next_token_logits = outputs.logits[:, -1]
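Continuing the example, a greedy next token can be selected from next_token_logits and appended to the input sequence. This follow-up is an illustrative sketch, not part of the original reference.

import numpy as np

# greedy choice of the next token id for each sequence in the batch
next_token_id = np.argmax(np.asarray(next_token_logits), axis=-1)

# append it to the running sequence and decode batch element 0 for inspection
generated = np.concatenate([np.asarray(inputs["input_ids"]), next_token_id[:, None]], axis=-1)
print(tokenizer.decode(generated[0], skip_special_tokens=True))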