mBART

FlaxMBartModel

class transformers.FlaxMBartModel

( config: MBartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: jnp.dtype = jnp.float32 _do_init: bool = True **kwargs )

The bare MBart Model transformer outputting raw hidden-states without any specific head on top. This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a Flax Linen flax.nn.Module subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

Just-In-Time (JIT) compilation
Automatic Differentiation
Vectorization
Parallelization
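For instance, the forward pass can be JIT-compiled. This is a minimal sketch, not part of the documented API: the forward wrapper is illustrative, and it reuses the checkpoint from the examples below.

import jax
from transformers import AutoTokenizer, FlaxMBartModel

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = FlaxMBartModel.from_pretrained("facebook/mbart-large-cc25")

@jax.jit
def forward(input_ids, attention_mask):
    # the model call is a pure function of its inputs, so it traces and compiles cleanly
    return model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
last_hidden_states = forward(inputs["input_ids"], inputs["attention_mask"])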

__call__

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None decoder_input_ids: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxSeq2SeqModelOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxSeq2SeqModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False), comprising various elements depending on the configuration (MBartConfig) and inputs.

The FlaxMBartPreTrainedModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, FlaxMBartModel

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = FlaxMBartModel.from_pretrained("facebook/mbart-large-cc25")

inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state
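last_hidden_states is a jnp.ndarray of shape (batch_size, sequence_length, hidden_size); for the mbart-large checkpoints the hidden size is 1024:

# quick sanity check on the output shape
print(last_hidden_states.shape)  # (1, sequence_length, 1024)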

encode

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False), comprising various elements depending on the configuration (MBartConfig) and inputs.

Example:

from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration

model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decode

( decoder_input_ids encoder_outputs encoder_attention_mask: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None past_key_values: typing.Optional[dict] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False), comprising various elements depending on the configuration (MBartConfig) and inputs.

Example:

from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration

import jax.numpy as jnp

model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decoder_start_token_id = model.config.decoder_start_token_id
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

outputs = model.decode(decoder_input_ids, encoder_outputs)
last_decoder_hidden_states = outputs.last_hidden_state
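When generating token by token, the past_key_values cache avoids recomputing attention over earlier positions. A hedged sketch of a single cached decoding step, continuing the example above; init_cache is the cache initializer the Flax MBart models expose, and max_length here is illustrative:

# allocate a cache large enough for the planned generation length
max_length = 10
past_key_values = model.init_cache(decoder_input_ids.shape[0], max_length, encoder_outputs)

# positions must be passed explicitly when a cache is used
outputs = model.decode(
    decoder_input_ids,
    encoder_outputs,
    past_key_values=past_key_values,
    decoder_position_ids=jnp.zeros_like(decoder_input_ids),
)
# outputs.past_key_values now holds the updated cache for the next step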

FlaxMBartForConditionalGeneration

class transformers.FlaxMBartForConditionalGeneration

( config: MBartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: jnp.dtype = jnp.float32 _do_init: bool = True **kwargs )

The MBart Model with a language modeling head. Can be used for summarization. This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a Flax Linen flax.nn.Module subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

Just-In-Time (JIT) compilation
Automatic Differentiation
Vectorization
Parallelization

__call__

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None decoder_input_ids: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxSeq2SeqLMOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxSeq2SeqLMOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False), comprising various elements depending on the configuration (MBartConfig) and inputs.

The FlaxMBartPreTrainedModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Summarization example:

from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration, MBartConfig

model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

ARTICLE_TO_SUMMARIZE = "Meine Freunde sind cool, aber sie essen zu viel Kuchen."
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="np")

summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=5).sequences
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
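For translation rather than summarization, mBART expects the source language to be set on the tokenizer and the target language code to be forced as the first generated token. A hedged sketch using the fine-tuned facebook/mbart-large-en-ro checkpoint (pass from_pt=True if no Flax weights are published for it):

from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX")
model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro", from_pt=True)

inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="np")
# force the Romanian language code as the first generated token
generated = model.generate(
    inputs["input_ids"],
    forced_bos_token_id=tokenizer.lang_code_to_id["ro_RO"],
    num_beams=4,
    max_length=40,
).sequences
print(tokenizer.batch_decode(generated, skip_special_tokens=True))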

Mask filling example:

from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration

import jax

model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

# de_DE is the language code for German
TXT = "</s> Meine Freunde sind <mask> nett aber sie essen zu viel Kuchen. </s> de_DE"
input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="np")["input_ids"]

logits = model(input_ids).logits
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
# use JAX ops here: numpy/jax arrays have no .softmax or .topk methods
probs = jax.nn.softmax(logits[0, masked_index], axis=-1)
values, predictions = jax.lax.top_k(probs, 5)

tokenizer.decode(predictions).split()

encode

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False), comprising various elements depending on the configuration (MBartConfig) and inputs.

Example:

from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration

model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decode

( decoder_input_ids encoder_outputs encoder_attention_mask: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None past_key_values: typing.Optional[dict] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False), comprising various elements depending on the configuration (MBartConfig) and inputs.

Example:

from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration

import jax.numpy as jnp

model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decoder_start_token_id = model.config.decoder_start_token_id
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

outputs = model.decode(decoder_input_ids, encoder_outputs)
logits = outputs.logits

FlaxMBartForSequenceClassification

class transformers.FlaxMBartForSequenceClassification

( config: MBartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: jnp.dtype = jnp.float32 _do_init: bool = True **kwargs )

MBart model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE tasks.

This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a Flax Linen flax.nn.Module subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

Just-In-Time (JIT) compilation
Automatic Differentiation
Vectorization
Parallelization

__call__

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None decoder_input_ids: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxSeq2SeqSequenceClassifierOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxSeq2SeqSequenceClassifierOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False), comprising various elements depending on the configuration (MBartConfig) and inputs.

The FlaxMBartPreTrainedModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, FlaxMBartForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = FlaxMBartForSequenceClassification.from_pretrained("facebook/mbart-large-cc25")

inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

outputs = model(**inputs)
logits = outputs.logits
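Since facebook/mbart-large-cc25 is a pretrained (not fine-tuned) checkpoint, the classification head is randomly initialized and the logits are illustrative only; turning them into class probabilities is still straightforward:

import jax

# softmax over the label dimension; argmax picks the highest-scoring class
probs = jax.nn.softmax(logits, axis=-1)
predicted_class = probs.argmax(axis=-1)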

encode

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False), comprising various elements depending on the configuration (MBartConfig) and inputs.

Example:

from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration

model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decode

( decoder_input_ids encoder_outputs encoder_attention_mask: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None past_key_values: typing.Optional[dict] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False), comprising various elements depending on the configuration (MBartConfig) and inputs.

Example:

from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration

import jax.numpy as jnp

model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decoder_start_token_id = model.config.decoder_start_token_id
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

outputs = model.decode(decoder_input_ids, encoder_outputs)
last_decoder_hidden_states = outputs.last_hidden_state

FlaxMBartForQuestionAnswering

class transformers.FlaxMBartForQuestionAnswering

( config: MBartConfig input_shape: typing.Tuple[int] = (1, 1) seed: int = 0 dtype: jnp.dtype = jnp.float32 _do_init: bool = True **kwargs )

MBart Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute span start logits and span end logits).

This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a Flax Linen flax.nn.Module subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

Just-In-Time (JIT) compilation
Automatic Differentiation
Vectorization
Parallelization

__call__

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None decoder_input_ids: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxSeq2SeqQuestionAnsweringModelOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxSeq2SeqQuestionAnsweringModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False), comprising various elements depending on the configuration (MBartConfig) and inputs.

The FlaxMBartPreTrainedModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, FlaxMBartForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = FlaxMBartForQuestionAnswering.from_pretrained("facebook/mbart-large-cc25")

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, text, return_tensors="jax")

outputs = model(**inputs)
start_scores = outputs.start_logits
end_scores = outputs.end_logits
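As with the classification model, the QA head on this pretrained checkpoint is not fine-tuned, so the span below is illustrative. A common post-processing sketch, assuming the argmax start position precedes the argmax end position:

start_index = start_scores[0].argmax()
end_index = end_scores[0].argmax()
# decode the highest-scoring span back to text
answer_tokens = inputs["input_ids"][0, start_index : end_index + 1]
print(tokenizer.decode(answer_tokens, skip_special_tokens=True))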

encode

( input_ids: Array attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutput or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutput or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False), comprising various elements depending on the configuration (MBartConfig) and inputs.

Example:

from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration

model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decode

( decoder_input_ids encoder_outputs encoder_attention_mask: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None past_key_values: typing.Optional[dict] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or tuple(jnp.ndarray)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions or a tuple of jnp.ndarray (if return_dict=False is passed or when config.return_dict=False), comprising various elements depending on the configuration (MBartConfig) and inputs.

Example:

from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration

import jax.numpy as jnp

model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

text = "My friends are cool but they eat too many carbs."
inputs = tokenizer(text, max_length=1024, return_tensors="jax")
encoder_outputs = model.encode(**inputs)

decoder_start_token_id = model.config.decoder_start_token_id
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

outputs = model.decode(decoder_input_ids, encoder_outputs)
last_decoder_hidden_states = outputs.last_hidden_state