GPTBigCode (original) (raw)

PyTorch FlashAttention SDPA

Overview

The GPTBigCode model was proposed in SantaCoder: don’t reach for the stars! by BigCode. The listed authors are: Loubna Ben Allal, Raymond Li, Denis Kocetkov, Chenghao Mou, Christopher Akiki, Carlos Munoz Ferrandis, Niklas Muennighoff, Mayank Mishra, Alex Gu, Manan Dey, Logesh Kumar Umapathi, Carolyn Jane Anderson, Yangtian Zi, Joel Lamy Poirier, Hailey Schoelkopf, Sergey Troshin, Dmitry Abulkhanov, Manuel Romero, Michael Lappert, Francesco De Toni, Bernardo García del Río, Qian Liu, Shamik Bose, Urvashi Bhattacharyya, Terry Yue Zhuo, Ian Yu, Paulo Villegas, Marco Zocca, Sourab Mangrulkar, David Lansky, Huu Nguyen, Danish Contractor, Luis Villa, Jia Li, Dzmitry Bahdanau, Yacine Jernite, Sean Hughes, Daniel Fried, Arjun Guha, Harm de Vries, Leandro von Werra.

The abstract from the paper is the following:

The BigCode project is an open-scientific collaboration working on the responsible development of large language models for code. This tech report describes the progress of the collaboration until December 2022, outlining the current state of the Personally Identifiable Information (PII) redaction pipeline, the experiments conducted to de-risk the model architecture, and the experiments investigating better preprocessing methods for the training data. We train 1.1B parameter models on the Java, JavaScript, and Python subsets of The Stack and evaluate them on the MultiPL-E text-to-code benchmark. We find that more aggressive filtering of near-duplicates can further boost performance and, surprisingly, that selecting files from repositories with 5+ GitHub stars deteriorates performance significantly. Our best model outperforms previous open-source multilingual code generation models (InCoder-6.7B and CodeGen-Multi-2.7B) in both left-to-right generation and infilling on the Java, JavaScript, and Python portions of MultiPL-E, despite being a substantially smaller model. All models are released under an OpenRAIL license at this https URL.

The model is an optimized GPT2 model with support for Multi-Query Attention.

Implementation details

The main differences compared to GPT2.

You can read more about the optimizations in the original pull request

Combining Starcoder and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.

pip install -U flash-attn --no-build-isolation

Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16“)

To load and run a model using Flash Attention 2, refer to the snippet below:

import torch from transformers import AutoModelForCausalLM, AutoTokenizer device = "cuda"

model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, attn_implementation="flash_attention_2") tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")

prompt = "def hello_world():"

model_inputs = tokenizer([prompt], return_tensors="pt").to(device) model.to(device)

generated_ids = model.generate(**model_inputs, max_new_tokens=30, do_sample=False) tokenizer.batch_decode(generated_ids)[0] 'def hello_world():\n print("hello world")\n\nif name == "main":\n print("hello world")\n<|endoftext|>'

Expected speedups

Below is a expected speedup diagram that compares pure inference time between the native implementation in transformers using bigcode/starcoder checkpoint and the Flash Attention 2 version of the model using two different sequence lengths.

GPTBigCodeConfig

class transformers.GPTBigCodeConfig

< source >

( vocab_size = 50257 n_positions = 1024 n_embd = 768 n_layer = 12 n_head = 12 n_inner = None activation_function = 'gelu_pytorch_tanh' resid_pdrop = 0.1 embd_pdrop = 0.1 attn_pdrop = 0.1 layer_norm_epsilon = 1e-05 initializer_range = 0.02 scale_attn_weights = True use_cache = True bos_token_id = 50256 eos_token_id = 50256 attention_softmax_in_fp32 = True scale_attention_softmax_in_fp32 = True multi_query = True **kwargs )

Parameters

This is the configuration class to store the configuration of a GPTBigCodeModel. It is used to instantiate a GPTBigCode model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the GPTBigCodegpt_bigcode architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

from transformers import GPTBigCodeConfig, GPTBigCodeModel

configuration = GPTBigCodeConfig()

model = GPTBigCodeModel(configuration)

configuration = model.config

GPTBigCodeModel

class transformers.GPTBigCodeModel

< source >

( config )

Parameters

The bare GPT_BIGCODE Model transformer outputting raw hidden-states without any specific head on top.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.)

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

forward

< source >

( input_ids: typing.Optional[torch.Tensor] = None past_key_values: typing.Optional[typing.List[torch.Tensor]] = None attention_mask: typing.Optional[torch.Tensor] = None token_type_ids: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None encoder_hidden_states: typing.Optional[torch.Tensor] = None encoder_attention_mask: typing.Optional[torch.Tensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions or a tuple oftorch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (GPTBigCodeConfig) and inputs.

The GPTBigCodeModel forward method, overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Moduleinstance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, GPTBigCodeModel import torch

tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder") model = GPTBigCodeModel.from_pretrained("bigcode/gpt_bigcode-santacoder")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state

GPTBigCodeForCausalLM

class transformers.GPTBigCodeForCausalLM

< source >

( config )

Parameters

The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.)

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

forward

< source >

( input_ids: typing.Optional[torch.Tensor] = None past_key_values: typing.Optional[typing.Tuple[typing.Tuple[torch.Tensor]]] = None attention_mask: typing.Optional[torch.Tensor] = None token_type_ids: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None encoder_hidden_states: typing.Optional[torch.Tensor] = None encoder_attention_mask: typing.Optional[torch.Tensor] = None labels: typing.Optional[torch.Tensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None **kwargs ) → transformers.modeling_outputs.CausalLMOutputWithCrossAttentions or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.CausalLMOutputWithCrossAttentions or a tuple oftorch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (GPTBigCodeConfig) and inputs.

The GPTBigCodeForCausalLM forward method, overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Moduleinstance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Example:

import torch from transformers import AutoTokenizer, GPTBigCodeForCausalLM

tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder") model = GPTBigCodeForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") outputs = model(**inputs, labels=inputs["input_ids"]) loss = outputs.loss logits = outputs.logits

GPTBigCodeForSequenceClassification

class transformers.GPTBigCodeForSequenceClassification

< source >

( config )

Parameters

The GPTBigCode Model transformer with a sequence classification head on top (linear layer).

GPTBigCodeForSequenceClassification uses the last token in order to do the classification, as other causal models (e.g. GPT-1) do.

Since it does classification on the last token, it requires to know the position of the last token. If apad_token_id is defined in the configuration, it finds the last token that is not a padding token in each row. If no pad_token_id is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when inputs_embeds are passed instead of input_ids, it does the same (take the last value in each row of the batch).

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.)

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

forward

< source >

( input_ids: typing.Optional[torch.Tensor] = None past_key_values: typing.Optional[typing.Tuple[typing.Tuple[torch.Tensor]]] = None attention_mask: typing.Optional[torch.Tensor] = None token_type_ids: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None labels: typing.Optional[torch.Tensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None )

Parameters

The GPTBigCodeForSequenceClassification forward method, overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Moduleinstance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

GPTBigCodeForTokenClassification

class transformers.GPTBigCodeForTokenClassification

< source >

( config )

Parameters

GPT_BIGCODE Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.)

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

forward

< source >

( input_ids: typing.Optional[torch.Tensor] = None past_key_values: typing.Optional[typing.Tuple[typing.Tuple[torch.Tensor]]] = None attention_mask: typing.Optional[torch.Tensor] = None token_type_ids: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.Tensor] = None head_mask: typing.Optional[torch.Tensor] = None inputs_embeds: typing.Optional[torch.Tensor] = None labels: typing.Optional[torch.Tensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None )

Parameters

The GPTBigCodeForTokenClassification forward method, overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Moduleinstance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

< > Update on GitHub