Generation

Each framework has a generate method for text generation implemented in its respective GenerationMixin class:

- PyTorch generate() is implemented in GenerationMixin.
- TensorFlow generate() is implemented in TFGenerationMixin.
- Flax/JAX generate() is implemented in FlaxGenerationMixin.

Regardless of your framework of choice, you can parameterize the generate method with a GenerationConfig class instance. Please refer to this class for the complete list of generation parameters, which control the behavior of the generation method.

To learn how to inspect a model’s generation configuration, what the defaults are, how to change the parameters ad hoc, and how to create and save a customized generation configuration, refer to the text generation strategies guide. The guide also explains how to use related features, like token streaming.
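As a quick illustration of the pattern above, here is a minimal sketch, assuming the PyTorch backend and the openai-community/gpt2 checkpoint, of passing a GenerationConfig instance to generate:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

# Bundle the generation parameters into a reusable configuration object.
generation_config = GenerationConfig(
    max_new_tokens=20, do_sample=True, top_k=50, pad_token_id=tokenizer.eos_token_id
)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```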

GenerationConfig

class transformers.GenerationConfig


( **kwargs )

The generation parameters fall into the following groups:

- Parameters that control the length of the output
- Parameters that control the generation strategy used
- Parameters that control the cache
- Parameters for manipulation of the model output logits
- Parameters that define the output variables of generate
- Special tokens that can be used at generation time
- Generation parameters exclusive to encoder-decoder models
- Generation parameters exclusive to assistant generation
- Parameters related to performance and compilation

Class that holds a configuration for a generation task. A generate call supports multiple decoding strategies for text-decoder, text-to-text, speech-to-text, and vision-to-text models, such as greedy decoding, multinomial sampling, beam-search decoding, and contrastive search.

To learn more about decoding strategies refer to the text generation strategies guide.

A large number of these flags control the logits or the stopping criteria of the generation. Make sure you check the generate-related classes for a full description of the possible manipulations, as well as examples of their usage.
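For instance, a single GenerationConfig typically combines flags from several of the groups above; a minimal sketch with arbitrary values:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(
    max_new_tokens=64,             # controls the length of the output
    do_sample=True,                # selects the generation strategy (multinomial sampling)
    temperature=0.7,               # manipulates the model output logits
    top_p=0.9,                     # manipulates the model output logits
    return_dict_in_generate=True,  # defines the output variables of generate
)
```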

from_pretrained


( pretrained_model_name: typing.Union[str, os.PathLike] config_file_name: typing.Union[str, os.PathLike, NoneType] = None cache_dir: typing.Union[str, os.PathLike, NoneType] = None force_download: bool = False local_files_only: bool = False token: typing.Union[bool, str, NoneType] = None revision: str = 'main' **kwargs ) → GenerationConfig

Returns

GenerationConfig

The configuration object instantiated from this pretrained model.

Instantiate a GenerationConfig from a generation configuration file.

Examples:

```python
from transformers import GenerationConfig

# Download the configuration from huggingface.co and cache it.
generation_config = GenerationConfig.from_pretrained("openai-community/gpt2")

# Save to and reload from a local directory.
generation_config.save_pretrained("./test/saved_model/")
generation_config = GenerationConfig.from_pretrained("./test/saved_model/")

# You can also specify a custom configuration file name.
generation_config.save_pretrained("./test/saved_model/", config_file_name="my_configuration.json")
generation_config = GenerationConfig.from_pretrained("./test/saved_model/", "my_configuration.json")

# Generation arguments can be passed to `from_pretrained`; with
# `return_unused_kwargs=True`, unmatched kwargs are returned separately.
generation_config, unused_kwargs = GenerationConfig.from_pretrained(
    "openai-community/gpt2", top_k=1, foo=False, do_sample=True, return_unused_kwargs=True
)
generation_config.top_k  # 1
unused_kwargs  # {'foo': False}
```

from_model_config


( model_config: PretrainedConfig ) → GenerationConfig

Returns

GenerationConfig

The configuration object instantiated from those parameters.

Instantiates a GenerationConfig from a PretrainedConfig. This function is useful to convert legacy PretrainedConfig objects, which may contain generation parameters, into a stand-alone GenerationConfig.
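A minimal sketch of this conversion, assuming the openai-community/gpt2 checkpoint, whose legacy model config contains generation defaults:

```python
from transformers import AutoConfig, GenerationConfig

# Extract the generation parameters stored in the legacy model config.
model_config = AutoConfig.from_pretrained("openai-community/gpt2")
generation_config = GenerationConfig.from_model_config(model_config)
```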

save_pretrained


( save_directory: typing.Union[str, os.PathLike] config_file_name: typing.Union[str, os.PathLike, NoneType] = None push_to_hub: bool = False **kwargs )


Save a generation configuration object to the directory save_directory, so that it can be re-loaded using the from_pretrained() class method.
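A minimal sketch of the save/load round trip (the directory path is arbitrary):

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(max_new_tokens=32, do_sample=True)
generation_config.save_pretrained("./my_model")  # writes ./my_model/generation_config.json
reloaded = GenerationConfig.from_pretrained("./my_model")
```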

update


( **kwargs ) → Dict[str, Any]

Returns

Dict[str, Any]

Dictionary containing all the key-value pairs that were not used to update the instance.

Updates attributes of this class instance with attributes from kwargs if they match existing attributes, returning all the unused kwargs.
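A minimal sketch of this behavior, where foo is a deliberately unknown key:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig()
unused = generation_config.update(temperature=0.7, foo="bar")
# `temperature` matches an existing attribute and is updated in place;
# `foo` does not, so it is returned unused: {'foo': 'bar'}
```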

validate


( strict = False )


Validates the values of the attributes of the GenerationConfig instance. Raises exceptions in the presence of parameterization that can be detected as incorrect from the configuration instance alone.

Note that some parameters not validated here are best validated at generate runtime, as they may depend on other inputs and/or the model, such as parameters related to the generation length.
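A minimal sketch, noting that the exact warning/exception behavior may vary across transformers versions:

```python
from transformers import GenerationConfig

# `temperature` only has an effect when sampling, so this parameterization is
# flagged by `validate` (as a warning by default, or an exception with strict=True).
config = GenerationConfig(do_sample=False, temperature=0.7)
config.validate(strict=True)
```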

get_generation_mode


( assistant_model: typing.Optional[ForwardRef('PreTrainedModel')] = None ) → GenerationMode

Returns

GenerationMode

The generation mode triggered by the instance.

Returns the generation mode triggered by the GenerationConfig instance.
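A minimal sketch of how the configured flags map to a generation mode:

```python
from transformers import GenerationConfig

config = GenerationConfig(num_beams=4, do_sample=False)
print(config.get_generation_mode())  # e.g. GenerationMode.BEAM_SEARCH

config = GenerationConfig(num_beams=1, do_sample=True)
print(config.get_generation_mode())  # e.g. GenerationMode.SAMPLE
```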

GenerationMixin

A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes. Inheriting from this class causes the model to have special generation-related behavior, such as loading a GenerationConfig at initialization time or ensuring generate-related tests are run in transformers CI.

A model class should inherit from GenerationMixin to enable calling methods like generate, or when it defines a custom generate method that relies on GenerationMixin, directly or indirectly, and shares approximately the same public interface.

The class exposes generate(), which can be used for all supported decoding strategies (greedy decoding, sampling, beam search, and more).

To learn more about decoding strategies refer to the text generation strategies guide.

generate


( inputs: typing.Optional[torch.Tensor] = None generation_config: typing.Optional[transformers.generation.configuration_utils.GenerationConfig] = None logits_processor: typing.Optional[transformers.generation.logits_process.LogitsProcessorList] = None stopping_criteria: typing.Optional[transformers.generation.stopping_criteria.StoppingCriteriaList] = None prefix_allowed_tokens_fn: typing.Optional[typing.Callable[[int, torch.Tensor], typing.List[int]]] = None synced_gpus: typing.Optional[bool] = None assistant_model: typing.Optional[ForwardRef('PreTrainedModel')] = None streamer: typing.Optional[ForwardRef('BaseStreamer')] = None negative_prompt_ids: typing.Optional[torch.Tensor] = None negative_prompt_attention_mask: typing.Optional[torch.Tensor] = None use_model_defaults: typing.Optional[bool] = None custom_generate: typing.Optional[str] = None **kwargs ) → ModelOutput or torch.LongTensor


Returns

ModelOutput or torch.LongTensor

A ModelOutput (if return_dict_in_generate=True or when config.return_dict_in_generate=True) or a torch.LongTensor.

If the model is not an encoder-decoder model (model.config.is_encoder_decoder=False), the possible ModelOutput types are GenerateDecoderOnlyOutput and GenerateBeamDecoderOnlyOutput.

If the model is an encoder-decoder model (model.config.is_encoder_decoder=True), the possible ModelOutput types are GenerateEncoderDecoderOutput and GenerateBeamEncoderDecoderOutput.

Generates sequences of token ids for models with a language modeling head.

Most generation-controlling parameters are set in generation_config which, if not passed, will be set to the model’s default generation configuration. You can override any generation_config by passing the corresponding parameters to generate(), e.g. .generate(inputs, num_beams=4, do_sample=True).

For an overview of generation strategies and code examples, check out the text generation strategies guide.
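Putting this together, a minimal sketch of a generate call with ad-hoc overrides, assuming the openai-community/gpt2 checkpoint:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

inputs = tokenizer("Hello, my name is", return_tensors="pt")
# num_beams and do_sample override the model's default generation configuration
outputs = model.generate(**inputs, max_new_tokens=20, num_beams=4, do_sample=True)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```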

compute_transition_scores


( sequences: Tensor scores: typing.Tuple[torch.Tensor] beam_indices: typing.Optional[torch.Tensor] = None normalize_logits: bool = False ) → torch.Tensor

Returns

torch.Tensor

A torch.Tensor of shape (batch_size*num_return_sequences, sequence_length) containing the transition scores (logits).

Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was used). This is a convenient method to quickly obtain the scores of the selected tokens at generation time.

Examples:

```python
from transformers import GPT2Tokenizer, AutoModelForCausalLM
import numpy as np

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer(["Today is"], return_tensors="pt")

# Example 1: Print the scores for each token generated with greedy search
outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)
# input_length is the length of the input prompt for decoder-only models, like
# the GPT family, and 1 for encoder-decoder models, like BART or T5.
input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
generated_tokens = outputs.sequences[:, input_length:]
for tok, score in zip(generated_tokens[0], transition_scores[0]):
    # | token | token string | log probability | probability
    print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
# |   262 |  the     | -1.414 | 24.33%
# |  1110 |  day     | -2.609 | 7.36%
# |   618 |  when    | -2.010 | 13.40%
# |   356 |  we      | -1.859 | 15.58%
# |   460 |  can     | -2.508 | 8.14%

# Example 2: Reconstruct the sequence scores from beam search
outputs = model.generate(
    **inputs,
    max_new_tokens=5,
    num_beams=4,
    num_return_sequences=4,
    return_dict_in_generate=True,
    output_scores=True,
)
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
)
# Summing the generated tokens' scores and applying the length penalty
# reconstructs the sequence scores.
output_length = np.sum(transition_scores.numpy() < 0, axis=1)
length_penalty = model.generation_config.length_penalty
reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
print(np.allclose(outputs.sequences_scores, reconstructed_scores))
# True
```

TFGenerationMixin

class transformers.TFGenerationMixin


( )

A class containing all of the functions supporting generation, to be used as a mixin in TFPreTrainedModel.

The class exposes generate(), which can be used for all supported decoding strategies (greedy decoding, sampling, beam search, and more).

You do not need to call the underlying decoding methods directly; pass custom parameter values to generate instead. To learn more about decoding strategies refer to the text generation strategies guide.

generate


( inputs: typing.Optional[tensorflow.python.framework.tensor.Tensor] = None generation_config: typing.Optional[transformers.generation.configuration_utils.GenerationConfig] = None logits_processor: typing.Optional[transformers.generation.tf_logits_process.TFLogitsProcessorList] = None seed = None **kwargs ) → ModelOutput or tf.Tensor


Returns

ModelOutput or tf.Tensor

A ModelOutput (if return_dict_in_generate=True or when config.return_dict_in_generate=True) or a tf.Tensor.

If the model is not an encoder-decoder model (model.config.is_encoder_decoder=False), the possible ModelOutput types are the TF decoder-only generation outputs (one per decoding strategy, e.g. TFGreedySearchDecoderOnlyOutput).

If the model is an encoder-decoder model (model.config.is_encoder_decoder=True), the possible ModelOutput types are the corresponding TF encoder-decoder generation outputs (e.g. TFGreedySearchEncoderDecoderOutput).

Generates sequences of token ids for models with a language modeling head.

Most generation-controlling parameters are set in generation_config which, if not passed, will be set to the model’s default generation configuration. You can override any generation_config by passing the corresponding parameters to generate, e.g. .generate(inputs, num_beams=4, do_sample=True).

For an overview of generation strategies and code examples, check out the text generation strategies guide.
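A minimal TensorFlow sketch mirroring the PyTorch example, assuming the openai-community/gpt2 checkpoint:

```python
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")

inputs = tokenizer("Hello, my name is", return_tensors="tf")
# `seed` is the stateless seed pair used by TF sampling
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=True, seed=[42, 0])
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```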

compute_transition_scores


( sequences: Tensor scores: typing.Tuple[tensorflow.python.framework.tensor.Tensor] beam_indices: typing.Optional[tensorflow.python.framework.tensor.Tensor] = None normalize_logits: bool = False ) → tf.Tensor

Returns

tf.Tensor

A tf.Tensor of shape (batch_size*num_return_sequences, sequence_length) containing the transition scores (logits).

Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was used). This is a convenient method to quickly obtain the scores of the selected tokens at generation time.

Examples:

```python
from transformers import GPT2Tokenizer, TFAutoModelForCausalLM
import numpy as np

tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer(["Today is"], return_tensors="tf")

# Example 1: Print the scores for each token generated with greedy search
outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)
# input_length is the length of the input prompt for decoder-only models, like
# the GPT family, and 1 for encoder-decoder models, like BART or T5.
input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
generated_tokens = outputs.sequences[:, input_length:]
for tok, score in zip(generated_tokens[0], transition_scores[0]):
    # | token | token string | log probability | probability
    print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
# |   262 |  the     | -1.414 | 24.33%
# |  1110 |  day     | -2.609 | 7.36%
# |   618 |  when    | -2.010 | 13.40%
# |   356 |  we      | -1.859 | 15.58%
# |   460 |  can     | -2.508 | 8.14%

# Example 2: Reconstruct the sequence scores from beam search
outputs = model.generate(
    **inputs,
    max_new_tokens=5,
    num_beams=4,
    num_return_sequences=4,
    return_dict_in_generate=True,
    output_scores=True,
)
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
)
# Summing the generated tokens' scores and applying the length penalty
# reconstructs the sequence scores.
output_length = np.sum(transition_scores.numpy() < 0, axis=1)
length_penalty = model.generation_config.length_penalty
reconstructed_scores = np.sum(transition_scores, axis=1) / (output_length**length_penalty)
print(np.allclose(outputs.sequences_scores, reconstructed_scores))
# True
```

FlaxGenerationMixin

class transformers.FlaxGenerationMixin


( )

A class containing all functions for auto-regressive text generation, to be used as a mixin in FlaxPreTrainedModel.

The class exposes generate(), which can be used for all supported decoding strategies (greedy decoding, sampling, beam search, and more).

You do not need to call the underlying decoding methods directly; pass custom parameter values to generate instead. To learn more about decoding strategies refer to the text generation strategies guide.

generate


( input_ids: Array generation_config: typing.Optional[transformers.generation.configuration_utils.GenerationConfig] = None prng_key: typing.Optional[jax.Array] = None trace: bool = True params: typing.Optional[typing.Dict[str, jax.Array]] = None logits_processor: typing.Optional[transformers.generation.flax_logits_process.FlaxLogitsProcessorList] = None **kwargs )


Generates sequences of token ids for models with a language modeling head.
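A minimal Flax sketch, assuming the openai-community/gpt2 checkpoint; sampling in Flax requires an explicit PRNG key:

```python
import jax
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = FlaxAutoModelForCausalLM.from_pretrained("openai-community/gpt2")

inputs = tokenizer("Hello, my name is", return_tensors="np")
outputs = model.generate(
    inputs.input_ids, max_new_tokens=20, do_sample=True, prng_key=jax.random.PRNGKey(0)
)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0])
```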
