Llama (original) (raw)

Llama is a family of large language models ranging from 7B to 65B parameters. These models are focused on efficient inference (important for serving language models) by training a smaller model on more tokens rather than training a larger model on fewer tokens. The Llama model is based on the GPT architecture, but it uses pre-normalization to improve training stability, replaces ReLU with SwiGLU to improve performance, and replaces absolute positional embeddings with rotary positional embeddings (RoPE) to better handle longer sequence lengths.
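As a rough illustration of why rotary embeddings encode relative position, the sketch below rotates consecutive pairs of a query and key vector by position-dependent angles and checks that the attention score depends only on the offset between the two positions. The function names, the toy vectors, and the interleaved pairing are assumptions made for this example; library implementations may pair dimensions differently.

```python
import math


def rope_rotate(vec, position, theta=10000.0):
    """Rotate each consecutive pair of `vec` by an angle proportional to `position`."""
    dim = len(vec)
    if dim % 2 != 0:
        raise ValueError("RoPE needs an even number of dimensions")
    out = []
    for i in range(0, dim, 2):
        # Lower pairs rotate quickly, higher pairs rotate slowly.
        freq = theta ** (-i / dim)
        angle = position * freq
        cos_a, sin_a = math.cos(angle), math.sin(angle)
        x, y = vec[i], vec[i + 1]
        out.extend([x * cos_a - y * sin_a, x * sin_a + y * cos_a])
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Toy query/key vectors (hypothetical values).
query = [1.0, 0.5, -0.3, 2.0]
key = [0.2, -1.0, 0.7, 0.1]

# Both pairs of positions are 3 apart, so the scores should match.
score_a = dot(rope_rotate(query, 5), rope_rotate(key, 2))
score_b = dot(rope_rotate(query, 105), rope_rotate(key, 102))
print(abs(score_a - score_b) < 1e-9)
```

Because each rotation preserves vector length, only the angle between query and key changes with position, which is what lets the score carry relative rather than absolute position.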

You can find all the original Llama checkpoints under the Huggy Llama organization.

Click on the Llama models in the right sidebar for more examples of how to apply Llama to different language tasks.

The example below demonstrates how to generate text with Pipeline.

import torch
from transformers import pipeline

pipe = pipeline(
    task="text-generation",
    model="huggyllama/llama-7b",
    torch_dtype=torch.float16,
    device=0,
)
pipe("Plants create energy through a process known as")

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.

The example below uses torchao to quantize only the weights to int4.

import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-30b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-30b")
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")

output = model.generate(**input_ids, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
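To make "representing the weights in a lower precision" more concrete, here is a simplified sketch of group-wise 4-bit weight quantization in plain Python. It is not torchao's implementation: the min/max scheme, the helper names, and the toy weights are assumptions chosen for illustration only.

```python
def quantize_group(weights, n_bits=4):
    """Map one group of floats to integer codes in [0, 2**n_bits - 1]."""
    levels = 2 ** n_bits - 1
    w_min, w_max = min(weights), max(weights)
    # Fall back to a scale of 1.0 when every weight in the group is identical.
    scale = (w_max - w_min) / levels or 1.0
    codes = [round((w - w_min) / scale) for w in weights]
    return codes, scale, w_min


def dequantize_group(codes, scale, w_min):
    """Recover approximate floats from the integer codes."""
    return [c * scale + w_min for c in codes]


def quantize_weights(weights, group_size):
    """Quantize a flat weight list, one (codes, scale, minimum) triple per group."""
    groups = [weights[i:i + group_size] for i in range(0, len(weights), group_size)]
    return [quantize_group(group) for group in groups]


# Toy weights (hypothetical values); a real layer holds millions of them.
weights = [0.12, -0.40, 0.33, 0.05, -0.27, 0.48, -0.09, 0.21]
packed = quantize_weights(weights, group_size=4)

restored = []
for codes, scale, w_min in packed:
    restored.extend(dequantize_group(codes, scale, w_min))

max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(packed[0][0], round(max_error, 4))
```

Each group stores small integer codes plus one scale and one offset, so a smaller group size costs more metadata but tracks the local weight range more closely.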

Use the AttentionMaskVisualizer to better understand what tokens the model can and cannot attend to.

from transformers.utils.attention_visualizer import AttentionMaskVisualizer

visualizer = AttentionMaskVisualizer("huggyllama/llama-7b")
visualizer("Plants create energy through a process known as")
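For intuition about what such a visualization shows for a decoder-only model, the sketch below builds a plain causal mask by hand: each position may attend to itself and to earlier positions only. Splitting the prompt on whitespace is a simplification for this example (the real tokenizer produces subword tokens), and the helper names are hypothetical.

```python
def causal_mask(num_tokens):
    """mask[row][col] is True when position `row` may attend to position `col`."""
    return [[col <= row for col in range(num_tokens)] for row in range(num_tokens)]


def render(mask, tokens):
    """Return one text line per token: 'x' = can attend, '.' = masked out."""
    lines = []
    for token, row in zip(tokens, mask):
        cells = " ".join("x" if allowed else "." for allowed in row)
        lines.append(f"{token:>8} | {cells}")
    return lines


# Whitespace "tokens" stand in for real subword tokens.
tokens = "Plants create energy through".split()
mask = causal_mask(len(tokens))
print("\n".join(render(mask, tokens)))
```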

Notes

LlamaConfig

class transformers.LlamaConfig

( vocab_size = 32000 hidden_size = 4096 intermediate_size = 11008 num_hidden_layers = 32 num_attention_heads = 32 num_key_value_heads = None hidden_act = 'silu' max_position_embeddings = 2048 initializer_range = 0.02 rms_norm_eps = 1e-06 use_cache = True pad_token_id = None bos_token_id = 1 eos_token_id = 2 pretraining_tp = 1 tie_word_embeddings = False rope_theta = 10000.0 rope_scaling = None attention_bias = False attention_dropout = 0.0 mlp_bias = False head_dim = None **kwargs )

This is the configuration class to store the configuration of a LlamaModel. It is used to instantiate a LLaMA model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of LLaMA-7B.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

from transformers import LlamaModel, LlamaConfig

configuration = LlamaConfig()

model = LlamaModel(configuration)

configuration = model.config
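As a sanity check that the default values correspond to a model of roughly 7B parameters, the weights can be counted from the configuration fields. The function below is a back-of-the-envelope sketch, not a library API; it assumes standard multi-head attention (num_key_value_heads equal to num_attention_heads), no attention or MLP biases, and untied input and output embeddings, matching the defaults in the signature.

```python
def llama_param_count(
    vocab_size=32000,
    hidden_size=4096,
    intermediate_size=11008,
    num_hidden_layers=32,
    tie_word_embeddings=False,
):
    """Count weights implied by the configuration (hypothetical helper)."""
    embed = vocab_size * hidden_size
    lm_head = 0 if tie_word_embeddings else vocab_size * hidden_size
    attention = 4 * hidden_size * hidden_size  # q, k, v and output projections
    mlp = 3 * hidden_size * intermediate_size  # gate, up and down projections
    norms = 2 * hidden_size  # two RMSNorm weight vectors per layer
    final_norm = hidden_size
    per_layer = attention + mlp + norms
    return embed + lm_head + num_hidden_layers * per_layer + final_norm


print(f"{llama_param_count():,}")  # 6,738,415,616
```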

LlamaTokenizer

class transformers.LlamaTokenizer

( vocab_file unk_token = '<unk>' bos_token = '<s>' eos_token = '</s>' pad_token = None sp_model_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None add_bos_token = True add_eos_token = False clean_up_tokenization_spaces = False use_default_system_prompt = False spaces_between_special_tokens = False legacy = None add_prefix_space = True **kwargs )

Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is no padding token in the original model.

build_inputs_with_special_tokens

( token_ids_0 token_ids_1 = None )

get_special_tokens_mask

( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None already_has_special_tokens: bool = False ) → List[int]

Returns

A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.

Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer prepare_for_model method.
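A minimal sketch of the kind of mask this method describes is shown below. It is a standalone illustration rather than the tokenizer's own code: the function name is hypothetical, and it assumes a BOS marker is prepended and an EOS marker appended to each sequence according to the add_bos_token and add_eos_token defaults from the class signature.

```python
def special_tokens_mask(token_ids_0, token_ids_1=None, add_bos_token=True, add_eos_token=False):
    """Return 1 for positions that will hold special tokens, 0 for sequence tokens."""
    bos = [1] if add_bos_token else []
    eos = [1] if add_eos_token else []
    mask = bos + [0] * len(token_ids_0) + eos
    if token_ids_1 is not None:
        # Assumption: the second sequence is wrapped the same way as the first.
        mask += bos + [0] * len(token_ids_1) + eos
    return mask


print(special_tokens_mask([15043, 445, 338]))  # [1, 0, 0, 0]
print(special_tokens_mask([15043, 445], [263, 1243], add_eos_token=True))  # [1, 0, 0, 1, 1, 0, 0, 1]
```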

create_token_type_ids_from_sequences

( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None ) → List[int]

Returns

List of token type IDs according to the given sequence(s).

Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT sequence pair mask has the following format:

0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence    | second sequence |

If token_ids_1 is None, only returns the first portion of the mask (0s).
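The format above can be reproduced with a few lines of plain Python. This is a sketch under stated assumptions, not the tokenizer's implementation: the function name is hypothetical, and it assumes that any BOS or EOS token takes the type ID of the sequence it wraps.

```python
def token_type_ids(token_ids_0, token_ids_1=None, add_bos_token=True, add_eos_token=False):
    """Return 0 for every position of the first sequence and 1 for the second."""
    # Number of special tokens wrapped around each sequence.
    extra = int(add_bos_token) + int(add_eos_token)
    ids = [0] * (len(token_ids_0) + extra)
    if token_ids_1 is not None:
        ids += [1] * (len(token_ids_1) + extra)
    return ids


print(token_type_ids([15043, 445, 338]))  # [0, 0, 0, 0]
print(token_type_ids([15043, 445], [263, 1243, 29889]))  # [0, 0, 0, 1, 1, 1, 1]
```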

save_vocabulary

( save_directory filename_prefix: typing.Optional[str] = None ) → Tuple(str)

Returns

Paths to the files saved.

Save the vocabulary and special tokens file to a directory.

LlamaTokenizerFast

class transformers.LlamaTokenizerFast

( vocab_file = None tokenizer_file = None clean_up_tokenization_spaces = False unk_token = '<unk>' bos_token = '<s>' eos_token = '</s>' add_bos_token = True add_eos_token = False use_default_system_prompt = False legacy = None add_prefix_space = None **kwargs )

Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.

This uses notably ByteFallback and no normalization.

from transformers import LlamaTokenizerFast

tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.encode("Hello this is a test")
# [1, 15043, 445, 338, 263, 1243]

If you want to change the bos_token or the eos_token, make sure to specify them when initializing the model, or call tokenizer.update_post_processor() so that the post-processing is done correctly (otherwise the first and final tokens of an encoded sequence will not be correct). For more details, check out the post-processors documentation (https://huggingface.co/docs/tokenizers/api/post-processors).

This tokenizer inherits from PreTrainedTokenizerFast which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

build_inputs_with_special_tokens

( token_ids_0 token_ids_1 = None )

get_special_tokens_mask

( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None already_has_special_tokens: bool = False ) → A list of integers in the range [0, 1]

Returns

A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.

Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer prepare_for_model or encode_plus methods.

create_token_type_ids_from_sequences

( token_ids_0: typing.List[int] token_ids_1: typing.Optional[typing.List[int]] = None ) → List[int]

Returns

The token type ids.

Create the token type IDs corresponding to the sequences passed. What are token type IDs?

Should be overridden in a subclass if the model has a special way of building those.

update_post_processor

Updates the underlying post processor with the current bos_token and eos_token.

save_vocabulary

( save_directory: str filename_prefix: typing.Optional[str] = None )

LlamaModel

class transformers.LlamaModel

( config: LlamaConfig )

The bare LLaMA Model outputting raw hidden-states without any specific head on top. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a LlamaDecoderLayer
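The wiring of one such layer can be sketched in plain Python to show how pre-normalization and the SwiGLU feed-forward fit together. This is an illustrative toy, not the LlamaDecoderLayer code: the helper names, the tiny dimensions, and the identity function standing in for self-attention are all assumptions made for this example.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Scale `x` by the inverse of its root mean square, then by a learned weight."""
    mean_square = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [w * v * inv_rms for w, v in zip(weight, x)]


def silu(v):
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def matvec(matrix, x):
    return [sum(m * v for m, v in zip(row, x)) for row in matrix]


def swiglu_mlp(x, gate_proj, up_proj, down_proj):
    """Gated feed-forward: down(silu(gate(x)) * up(x))."""
    gate = [silu(v) for v in matvec(gate_proj, x)]
    up = matvec(up_proj, x)
    return matvec(down_proj, [g * u for g, u in zip(gate, up)])


def decoder_layer(x, attention_fn, norm1, norm2, gate_proj, up_proj, down_proj):
    """Pre-norm residual block: normalize before each sublayer, add the result back."""
    attn_out = attention_fn(rms_norm(x, norm1))
    x = [a + b for a, b in zip(x, attn_out)]
    mlp_out = swiglu_mlp(rms_norm(x, norm2), gate_proj, up_proj, down_proj)
    return [a + b for a, b in zip(x, mlp_out)]


# Toy sizes: hidden size 2, intermediate size 3 (hypothetical values).
hidden = [1.0, -2.0]
gate_proj = [[0.5, -0.1], [0.2, 0.3], [-0.4, 0.1]]
up_proj = [[0.1, 0.2], [-0.3, 0.4], [0.6, -0.5]]
down_proj = [[0.2, -0.1, 0.3], [0.1, 0.4, -0.2]]

# The identity function is a placeholder for self-attention.
output = decoder_layer(hidden, lambda h: h, [1.0, 1.0], [1.0, 1.0], gate_proj, up_proj, down_proj)
print(output)
```

Normalizing the input of each sublayer, rather than its output, keeps the residual path untouched, which is the property usually credited with the improved training stability.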

forward

( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.cache_utils.Cache] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None cache_position: typing.Optional[torch.LongTensor] = None **flash_attn_kwargs: typing_extensions.Unpack[transformers.modeling_flash_attention_utils.FlashAttentionKwargs] )

The LlamaModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

LlamaForCausalLM

class transformers.LlamaForCausalLM

( config )

forward

( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.cache_utils.Cache] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None cache_position: typing.Optional[torch.LongTensor] = None logits_to_keep: typing.Union[int, torch.Tensor] = 0 **kwargs: typing_extensions.Unpack[transformers.models.llama.modeling_llama.KwargsForCausalLM] ) → transformers.modeling_outputs.CausalLMOutputWithPast or tuple(torch.FloatTensor)

Returns

A transformers.modeling_outputs.CausalLMOutputWithPast or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (LlamaConfig) and inputs.

The LlamaForCausalLM forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")

generate_ids = model.generate(inputs.input_ids, max_length=30)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
# "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."

LlamaForSequenceClassification

class transformers.LlamaForSequenceClassification

( config )

The LLaMA Model transformer with a sequence classification head on top (linear layer).

LlamaForSequenceClassification uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do.

Since it does classification on the last token, it needs to know the position of the last token. If a pad_token_id is defined in the configuration, it finds the last token that is not a padding token in each row. If no pad_token_id is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when inputs_embeds are passed instead of input_ids, it does the same (takes the last value in each row of the batch).
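The selection rule described above can be sketched in plain Python as follows. This is an illustration of the rule, not the model's tensor code: the function name is hypothetical, the pad ID of 0 is chosen only for the example (the original Llama checkpoints define no padding token), and the fallback for an all-padding row is an arbitrary choice of this sketch.

```python
def last_token_index(input_ids, pad_token_id=None):
    """Return the index of the token whose hidden state would be classified."""
    if pad_token_id is None:
        # No padding information: take the last position of the row.
        return len(input_ids) - 1
    for index in range(len(input_ids) - 1, -1, -1):
        if input_ids[index] != pad_token_id:
            return index
    return 0  # degenerate all-padding row; this sketch falls back to index 0


# Toy batch, right-padded with a hypothetical pad ID of 0.
batch = [
    [1, 15043, 445, 0, 0],
    [1, 338, 263, 1243, 29889],
]

print([last_token_index(row, pad_token_id=0) for row in batch])  # [2, 4]
print([last_token_index(row) for row in batch])  # [4, 4]
```

Without a pad_token_id the first row would be classified from a padding position, which is why batched inputs generally need the padding token configured.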

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.cache_utils.Cache] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None )

The LlamaForSequenceClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

LlamaForQuestionAnswering

class transformers.LlamaForQuestionAnswering

( config )

The Llama Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute span start logits and span end logits).

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.cache_utils.Cache] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None start_positions: typing.Optional[torch.LongTensor] = None end_positions: typing.Optional[torch.LongTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None **kwargs )

The LlamaForQuestionAnswering forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

LlamaForTokenClassification

class transformers.LlamaForTokenClassification

( config )

The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states output), e.g. for Named-Entity-Recognition (NER) tasks.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.cache_utils.Cache] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None ) → transformers.modeling_outputs.TokenClassifierOutput or tuple(torch.FloatTensor)

Returns

A transformers.modeling_outputs.TokenClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (LlamaConfig) and inputs.

The LlamaForTokenClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, LlamaForTokenClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = LlamaForTokenClassification.from_pretrained("meta-llama/Llama-2-7b-hf")

inputs = tokenizer(
    "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
)

with torch.no_grad():
    logits = model(**inputs).logits

predicted_token_class_ids = logits.argmax(-1)

predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]

labels = predicted_token_class_ids
loss = model(**inputs, labels=labels).loss

FlaxLlamaModel

class transformers.FlaxLlamaModel

( config: LlamaConfig input_shape: typing.Tuple = (1, 1) seed: int = 0 dtype: dtype = <class 'jax.numpy.float32'> _do_init: bool = True **kwargs )

The bare Llama Model transformer outputting raw hidden-states without any specific head on top.

This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a Flax Linen flax.nn.Module subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

Finally, this model supports inherent JAX features such as Just-In-Time (JIT) compilation, automatic differentiation, vectorization, and parallelization.

__call__

( input_ids attention_mask = None position_ids = None params: typing.Optional[dict] = None past_key_values: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None train: bool = False output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutput or tuple(torch.FloatTensor)

Returns

A transformers.modeling_flax_outputs.FlaxBaseModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (LlamaConfig) and inputs.

The FlaxLlamaPreTrainedModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

This example uses a random model as the real ones are all very big. To get proper results, you should use openlm-research/open_llama_3b_v2 instead of afmck/testing-llama-tiny. If you get out-of-memory when loading that checkpoint, you can try adding device_map="auto" in the from_pretrained call.

Example:

from transformers import AutoTokenizer, FlaxLlamaModel

tokenizer = AutoTokenizer.from_pretrained("afmck/testing-llama-tiny")
model = FlaxLlamaModel.from_pretrained("afmck/testing-llama-tiny")

inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state

FlaxLlamaForCausalLM

class transformers.FlaxLlamaForCausalLM

( config: LlamaConfig input_shape: typing.Tuple = (1, 1) seed: int = 0 dtype: dtype = <class 'jax.numpy.float32'> _do_init: bool = True **kwargs )

The Llama Model transformer with a language modeling head (linear layer) on top.

This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a Flax Linen flax.nn.Module subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

Finally, this model supports inherent JAX features such as Just-In-Time (JIT) compilation, automatic differentiation, vectorization, and parallelization.

__call__

( input_ids attention_mask = None position_ids = None params: typing.Optional[dict] = None past_key_values: typing.Optional[dict] = None dropout_rng: jax.random.PRNGKey = None train: bool = False output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_flax_outputs.FlaxMaskedLMOutput or tuple(torch.FloatTensor)

Returns

A transformers.modeling_flax_outputs.FlaxMaskedLMOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (LlamaConfig) and inputs.

The FlaxLlamaPreTrainedModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

This example uses a random model as the real ones are all very big. To get proper results, you should use openlm-research/open_llama_3b_v2 instead of afmck/testing-llama-tiny. If you get out-of-memory when loading that checkpoint, you can try adding device_map="auto" in the from_pretrained call.

Example:

from transformers import AutoTokenizer, FlaxLlamaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("afmck/testing-llama-tiny")
model = FlaxLlamaForCausalLM.from_pretrained("afmck/testing-llama-tiny")

inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(**inputs)

next_token_logits = outputs.logits[:, -1]
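The last line keeps one row of vocabulary scores per sequence, taken at the final position. One simple way to turn such scores into a next token is greedy selection, sketched below in plain Python with nested lists standing in for the array. The helper name and the toy scores are hypothetical, and greedy selection is only one of several decoding strategies.

```python
def greedy_next_tokens(next_token_logits):
    """Pick the highest-scoring vocabulary index for each sequence in the batch."""
    return [max(range(len(row)), key=row.__getitem__) for row in next_token_logits]


# Toy scores for a batch of 2 sequences over a 4-token vocabulary.
toy_logits = [
    [0.1, 2.3, -1.0, 0.7],
    [1.5, 0.2, 3.1, 3.0],
]

print(greedy_next_tokens(toy_logits))  # [1, 2]
```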
