Gemma2

PyTorch FlashAttention SDPA

Overview

The Gemma2 model was proposed in Gemma2: Open Models Based on Gemini Technology and Research by Gemma2 Team, Google. Two Gemma2 models were released, with parameter sizes of 9 billion (9B) and 27 billion (27B).

The abstract from the blog post is the following:

Now we’re officially releasing Gemma 2 to researchers and developers globally. Available in both 9 billion (9B) and 27 billion (27B) parameter sizes, Gemma 2 is higher-performing and more efficient at inference than the first generation, with significant safety advancements built in. In fact, at 27B, it offers competitive alternatives to models more than twice its size, delivering the kind of performance that was only possible with proprietary models as recently as December.

Tips:

This model was contributed by Arthur Zucker, Pedro Cuenca and Tom Aarsen.

Gemma2Config

class transformers.Gemma2Config


( vocab_size = 256000, hidden_size = 2304, intermediate_size = 9216, num_hidden_layers = 26, num_attention_heads = 8, num_key_value_heads = 4, head_dim = 256, hidden_activation = 'gelu_pytorch_tanh', max_position_embeddings = 8192, initializer_range = 0.02, rms_norm_eps = 1e-06, use_cache = True, pad_token_id = 0, eos_token_id = 1, bos_token_id = 2, tie_word_embeddings = True, rope_theta = 10000.0, attention_bias = False, attention_dropout = 0.0, query_pre_attn_scalar = 256, sliding_window = 4096, final_logit_softcapping = 30.0, attn_logit_softcapping = 50.0, cache_implementation = 'hybrid', **kwargs )


This is the configuration class to store the configuration of a Gemma2Model. It is used to instantiate a Gemma2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of the Gemma2-7B (e.g. google/gemma2-7b). Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

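from transformers import Gemma2Model, Gemma2Config

# Initializing a Gemma2 configuration with the default values
configuration = Gemma2Config()

# Initializing a model (with random weights) from that configuration
model = Gemma2Model(configuration)

# Accessing the model configuration
configuration = model.config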
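The attention-related defaults in the Gemma2Config signature fit together as sketched below. This is a plain-Python illustration, assuming queries are projected to num_attention_heads * head_dim features, keys and values to num_key_value_heads * head_dim features (grouped-query attention), and attention scores are scaled by query_pre_attn_scalar ** -0.5. attention_dims is a hypothetical helper, not part of Transformers.

```python
import math

# Default values copied from the Gemma2Config signature.
defaults = {
    "hidden_size": 2304,
    "num_attention_heads": 8,
    "num_key_value_heads": 4,
    "head_dim": 256,
    "query_pre_attn_scalar": 256,
}


def attention_dims(cfg):
    """Hypothetical helper: derive attention shapes from config values.

    Assumes q/k/v projection widths of heads * head_dim and a score
    scale of query_pre_attn_scalar ** -0.5 (assumptions of this sketch).
    """
    if cfg["num_attention_heads"] % cfg["num_key_value_heads"] != 0:
        raise ValueError("num_attention_heads must be a multiple of num_key_value_heads")
    return {
        "q_proj_out": cfg["num_attention_heads"] * cfg["head_dim"],
        "kv_proj_out": cfg["num_key_value_heads"] * cfg["head_dim"],
        # Number of query heads that share one key/value head.
        "queries_per_kv_head": cfg["num_attention_heads"] // cfg["num_key_value_heads"],
        "score_scale": 1.0 / math.sqrt(cfg["query_pre_attn_scalar"]),
    }


dims = attention_dims(defaults)
print(dims)
# {'q_proj_out': 2048, 'kv_proj_out': 1024, 'queries_per_kv_head': 2, 'score_scale': 0.0625}
```

Under these assumptions the query projection (2048 features) is narrower than hidden_size (2304), and each key/value head is shared by two query heads.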

Gemma2Model

class transformers.Gemma2Model


( config: Gemma2Config )


The bare Gemma2 Model outputting raw hidden-states without any specific head on top. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

The model is a Transformer decoder consisting of config.num_hidden_layers layers; each layer is a Gemma2DecoderLayer.
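The config sets sliding_window=4096 and cache_implementation='hybrid', and forward takes a HybridCache, reflecting that Gemma 2 mixes local sliding-window attention layers with global attention layers. The plain-Python sketch below builds both kinds of causal mask for a toy sequence. The exact window boundary (i - j < window) and which layers are local (layer_uses_sliding_window, a hypothetical helper) are assumptions of this sketch, not taken from this page.

```python
def sliding_window_causal_mask(seq_len, window=None):
    """Return a seq_len x seq_len mask; mask[i][j] is True if query i may attend to key j.

    Causal rule: j <= i. With a window (assumed semantics), additionally
    i - j < window, so each token sees itself plus at most window - 1
    earlier tokens. window=None gives ordinary global causal attention.
    """
    return [
        [j <= i and (window is None or i - j < window) for j in range(seq_len)]
        for i in range(seq_len)
    ]


def layer_uses_sliding_window(layer_idx):
    # Hypothetical helper: assumes even-indexed layers are local (sliding
    # window) and odd-indexed layers are global.
    return layer_idx % 2 == 0


def render(mask):
    # "x" = may attend, "." = masked out
    return ["".join("x" if allowed else "." for allowed in row) for row in mask]


for line in render(sliding_window_causal_mask(5, window=3)):
    print(line)
# x....
# xx...
# xxx..
# .xxx.
# ..xxx
```

Because local layers only ever need the last sliding_window keys while global layers need all of them, a cache that treats the two layer types differently is a natural fit.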

forward


( input_ids: typing.Optional[torch.LongTensor] = None, attention_mask: typing.Optional[torch.Tensor] = None, position_ids: typing.Optional[torch.LongTensor] = None, past_key_values: typing.Optional[transformers.cache_utils.HybridCache] = None, inputs_embeds: typing.Optional[torch.FloatTensor] = None, use_cache: typing.Optional[bool] = None, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None, cache_position: typing.Optional[torch.LongTensor] = None, last_cache_position: typing.Optional[int] = None, **flash_attn_kwargs: typing_extensions.Unpack[transformers.modeling_flash_attention_utils.FlashAttentionKwargs] )


The Gemma2Model forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Gemma2ForCausalLM

class transformers.Gemma2ForCausalLM


( config )

forward


( input_ids: typing.Optional[torch.LongTensor] = None, attention_mask: typing.Optional[torch.Tensor] = None, position_ids: typing.Optional[torch.LongTensor] = None, past_key_values: typing.Optional[transformers.cache_utils.HybridCache] = None, inputs_embeds: typing.Optional[torch.FloatTensor] = None, labels: typing.Optional[torch.LongTensor] = None, use_cache: typing.Optional[bool] = None, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None, cache_position: typing.Optional[torch.LongTensor] = None, logits_to_keep: typing.Union[int, torch.Tensor] = 0, **loss_kwargs ) → transformers.modeling_outputs.CausalLMOutputWithPast or tuple(torch.FloatTensor)


Returns a transformers.modeling_outputs.CausalLMOutputWithPast or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Gemma2Config) and inputs.

The Gemma2ForCausalLM forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
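The forward signature of Gemma2ForCausalLM includes logits_to_keep (default 0). The plain-Python sketch below shows the selection this parameter is assumed to perform before the vocabulary projection: an integer n keeps the last n positions, 0 keeps all of them (slice(-0, None) is the same as slice(0, None)), and a sequence of indices keeps exactly those positions. select_positions_for_logits is a hypothetical helper, not a Transformers function.

```python
def select_positions_for_logits(hidden_states, logits_to_keep=0):
    """Hypothetical helper: choose which positions get vocabulary logits.

    int n   -> the last n positions (n == 0 keeps every position)
    indices -> exactly those positions, in the given order
    """
    if isinstance(logits_to_keep, int):
        # -0 == 0, so logits_to_keep=0 yields slice(0, None): the whole sequence
        return hidden_states[slice(-logits_to_keep, None)]
    return [hidden_states[i] for i in logits_to_keep]


positions = ["h0", "h1", "h2", "h3"]  # stand-ins for per-position hidden states
print(select_positions_for_logits(positions))          # ['h0', 'h1', 'h2', 'h3']
print(select_positions_for_logits(positions, 1))       # ['h3']
print(select_positions_for_logits(positions, [0, 2]))  # ['h0', 'h2']
```

During token-by-token generation only the last position's logits are needed, so keeping one position avoids computing 256,000-entry logit rows (the default vocab_size) that would be discarded.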

Example:

from transformers import AutoTokenizer, Gemma2ForCausalLM

model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

prompt = "What is your favorite condiment?"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate (max_length counts the prompt tokens too) and decode the first sequence
generate_ids = model.generate(**inputs, max_length=30)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
# The decoded text begins with the prompt: "What is your favorite condiment?"
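Gemma2Config sets final_logit_softcapping=30.0 (and attn_logit_softcapping=50.0 for attention scores). A minimal sketch of soft-capping, assuming the form cap * tanh(x / cap) applied elementwise to the logits; soft_cap is a hypothetical plain-Python helper, not the model's tensor code.

```python
import math


def soft_cap(logits, cap=30.0):
    """Hypothetical helper: smoothly squash each logit so |result| <= cap.

    cap * tanh(x / cap) is close to x when |x| is much smaller than cap
    and saturates toward +cap or -cap when |x| is large.
    """
    return [cap * math.tanh(x / cap) for x in logits]


raw = [0.0, 10.0, 100.0, -1000.0]
capped = soft_cap(raw)  # default cap mirrors final_logit_softcapping=30.0
for x, y in zip(raw, capped):
    print(f"{x:>8.1f} -> {y:8.3f}")
```

Unlike hard clipping, this keeps the mapping smooth and strictly increasing in exact arithmetic, so the ranking of moderate logits is preserved. Assuming the same form, calling soft_cap with cap=50.0 would describe the attention-score variant.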

Gemma2ForSequenceClassification

class transformers.Gemma2ForSequenceClassification


( config )


The Gemma2 Model transformer with a sequence classification head on top (linear layer).

Gemma2ForSequenceClassification uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do.

Since it does classification on the last token, it needs to know the position of the last token. If a pad_token_id is defined in the configuration, it finds the last token that is not a padding token in each row. If no pad_token_id is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when inputs_embeds are passed instead of input_ids, it does the same (takes the last value in each row of the batch).
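The last-token selection described above can be sketched in plain Python. This assumes "last token that is not a padding token" means the rightmost non-pad position, which covers both right- and left-padded rows. last_token_indices is a hypothetical helper, the token ids are made up, and pad_token_id=0 mirrors the Gemma2Config default.

```python
def last_token_indices(input_ids, pad_token_id=None):
    """Hypothetical helper: index of the token used for classification, per row.

    With pad_token_id: the rightmost position whose token is not padding.
    Without it: the last position of each row. An all-padding row falls
    back to index 0 (a choice made for this sketch).
    """
    indices = []
    for row in input_ids:
        if pad_token_id is None:
            indices.append(len(row) - 1)
        else:
            non_pad = [i for i, tok in enumerate(row) if tok != pad_token_id]
            indices.append(non_pad[-1] if non_pad else 0)
    return indices


batch = [
    [2, 651, 3413, 0, 0],  # right-padded (made-up ids, pad_token_id=0)
    [0, 0, 2, 651, 3413],  # left-padded
]
print(last_token_indices(batch, pad_token_id=0))  # [2, 4]
print(last_token_indices(batch))                  # [4, 4]
```

The sequence-level logits for row b would then be read from the per-token logits at position last_token_indices(batch, pad_token_id)[b]. Without a pad id the first row would be classified from a padding position, which is why defining pad_token_id matters for padded batches.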

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_ids: typing.Optional[torch.LongTensor] = None, attention_mask: typing.Optional[torch.Tensor] = None, position_ids: typing.Optional[torch.LongTensor] = None, past_key_values: typing.Optional[transformers.cache_utils.Cache] = None, inputs_embeds: typing.Optional[torch.FloatTensor] = None, labels: typing.Optional[torch.LongTensor] = None, use_cache: typing.Optional[bool] = None, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None )


The Gemma2ForSequenceClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Gemma2ForTokenClassification

class transformers.Gemma2ForTokenClassification


( config )


The Gemma2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.
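The token classification head applies one linear layer independently to every token's hidden state, so each token gets its own row of label logits. A framework-free sketch with toy numbers (hidden_size 2, two labels, made-up weights and a hypothetical id2label mapping); it leaves out dropout and the PyTorch layers the real model uses.

```python
def token_classification_head(hidden_states, weight, bias):
    """Apply the same linear layer to each token's hidden state.

    hidden_states: seq_len x hidden_size
    weight: num_labels x hidden_size, bias: num_labels
    Returns per-token logits of shape seq_len x num_labels.
    """
    return [
        [sum(w * h for w, h in zip(w_row, h_vec)) + b for w_row, b in zip(weight, bias)]
        for h_vec in hidden_states
    ]


def argmax(values):
    # Index of the largest value (first one on ties)
    return max(range(len(values)), key=lambda i: values[i])


# Toy numbers: 3 tokens, hidden_size 2, 2 labels (all made up).
hidden = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
weight = [[1.0, -1.0], [-1.0, 1.0]]
bias = [0.0, 0.1]
id2label = {0: "O", 1: "ENTITY"}

logits = token_classification_head(hidden, weight, bias)
print([id2label[argmax(row)] for row in logits])  # ['O', 'ENTITY', 'ENTITY']
```

Because predictions are made per token, a word split into several tokens receives several predicted labels.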

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward


( input_ids: typing.Optional[torch.LongTensor] = None, attention_mask: typing.Optional[torch.Tensor] = None, position_ids: typing.Optional[torch.LongTensor] = None, past_key_values: typing.Optional[transformers.cache_utils.Cache] = None, inputs_embeds: typing.Optional[torch.FloatTensor] = None, labels: typing.Optional[torch.LongTensor] = None, use_cache: typing.Optional[bool] = None, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None ) → transformers.modeling_outputs.TokenClassifierOutput or tuple(torch.FloatTensor)


Returns a transformers.modeling_outputs.TokenClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Gemma2Config) and inputs.

The Gemma2ForTokenClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, Gemma2ForTokenClassification
import torch

# The token-classification head is not part of this base checkpoint, so it is newly initialized (untrained)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = Gemma2ForTokenClassification.from_pretrained("google/gemma-2-9b")

inputs = tokenizer(
    "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
)

with torch.no_grad():
    logits = model(**inputs).logits

# One predicted class per token (not per word), mapped to label names
predicted_token_class_ids = logits.argmax(-1)
predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]

# Reuse the predictions as labels to compute a training loss
labels = predicted_token_class_ids
loss = model(**inputs, labels=labels).loss
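The last line of the example passes labels to obtain a loss. A plain-Python sketch of a per-token cross-entropy, assuming mean reduction over tokens and PyTorch's default ignore_index of -100 for positions that should not contribute; token_cross_entropy is a hypothetical helper, not the model's loss code.

```python
import math


def token_cross_entropy(logits, labels, ignore_index=-100):
    """Hypothetical helper: mean cross-entropy over tokens whose label != ignore_index.

    logits: seq_len x num_labels, labels: seq_len class ids.
    Returns NaN when every token is ignored (nothing to average).
    """
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue
        # log-sum-exp with max subtraction for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[label]
        count += 1
    return total / count if count else float("nan")


# Second token is ignored, so only the first contributes: -log(1/2) = log(2)
loss = token_cross_entropy([[0.0, 0.0], [2.0, 0.0]], [1, -100])
print(loss)
```

Using the model's own predictions as labels, as the example does, is only a smoke test of the loss path; real training would use annotated labels.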
