Mixtral

Overview

Mixtral-8x7B was introduced in the Mixtral of Experts blog post by Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.

The introduction of the blog post says:

Today, the team is proud to release Mixtral 8x7B, a high-quality sparse mixture of experts model (SMoE) with open weights. Licensed under Apache 2.0. Mixtral outperforms Llama 2 70B on most benchmarks with 6x faster inference. It is the strongest open-weight model with a permissive license and the best model overall regarding cost/performance trade-offs. In particular, it matches or outperforms GPT3.5 on most standard benchmarks.

Mixtral-8x7B is the second large language model (LLM) released by mistral.ai, after Mistral-7B.

Architectural details

Mixtral-8x7B is a decoder-only Transformer with the following architectural choices:

- It is a sparse Mixture of Experts (MoE) model: each feed-forward block contains 8 experts, and a router selects 2 of them per token (see num_local_experts and num_experts_per_tok in MixtralConfig), so only a fraction of the roughly 45 billion total parameters is active for any given token. A minimal routing sketch is shown after this section.

The following implementation details are shared with Mistral AI’s first model Mistral-7B:

- Sliding Window Attention (see the dedicated section below)
- Grouped Query Attention (GQA), with 8 key/value heads for 32 attention heads
- Byte-fallback BPE tokenizer

For more details refer to the release blog post.
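
The routing logic can be summarized in a few lines: for each token, a linear router scores the experts, the two highest-scoring experts are run, and their outputs are combined with the renormalized router weights. The block below is a minimal, self-contained sketch of top-2 routing for illustration only; it is not the actual MixtralSparseMoeBlock (whose experts use a gated SwiGLU-style MLP), and all sizes and names are arbitrary.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyTop2MoE(nn.Module):
    """Illustrative sparse MoE block: route each token to its top-2 of 8 expert MLPs."""

    def __init__(self, hidden_size=64, intermediate_size=128, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(hidden_size, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_size, intermediate_size),
                nn.SiLU(),
                nn.Linear(intermediate_size, hidden_size),
            )
            for _ in range(num_experts)
        )

    def forward(self, hidden_states):  # (batch, seq_len, hidden_size)
        batch, seq_len, hidden = hidden_states.shape
        tokens = hidden_states.reshape(-1, hidden)            # flatten tokens
        router_logits = self.router(tokens)                   # (num_tokens, num_experts)
        probs = F.softmax(router_logits, dim=-1)
        top_w, top_idx = probs.topk(self.top_k, dim=-1)       # top-2 experts per token
        top_w = top_w / top_w.sum(dim=-1, keepdim=True)       # renormalize the selected weights
        out = torch.zeros_like(tokens)
        for expert_id, expert in enumerate(self.experts):
            mask = top_idx == expert_id                        # which tokens picked this expert
            token_ids, slot = mask.nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue
            out[token_ids] += top_w[token_ids, slot, None] * expert(tokens[token_ids])
        return out.reshape(batch, seq_len, hidden)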

License

Mixtral-8x7B is released under the Apache 2.0 license.

Usage tips

The Mistral team has released 2 checkpoints:

- a base model, mistralai/Mixtral-8x7B-v0.1, pretrained to predict the next token;
- an instruction-tuned model, mistralai/Mixtral-8x7B-Instruct-v0.1, fine-tuned for chat.

The base model can be used as follows:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

prompt = "My favourite condiment is"

model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
tokenizer.batch_decode(generated_ids)[0]
# "My favourite condiment is to ..."

The instruction tuned model can be used as follows:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")

messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"},
]

model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")

generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
tokenizer.batch_decode(generated_ids)[0]
# "Mayonnaise can be made as follows: (...)"

As can be seen, the instruction-tuned model requires a chat template to be applied to make sure the inputs are prepared in the right format.
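
If you want to see the exact prompt string that the chat template produces (rather than token IDs), you can render it with tokenize=False. This is a small, optional addition to the example above and reuses its messages list.

# Render the template to a string to inspect the prompt format used by the Instruct model
prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt_text)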

Speeding up Mixtral by using Flash Attention

The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging Flash Attention, which is a faster implementation of the attention mechanism used inside the model.

First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.

pip install -U flash-attn --no-build-isolation

Also make sure that your hardware is compatible with Flash Attention 2; read more about this in the official documentation of the flash-attn repository. Also make sure to load your model in half precision (e.g. torch.float16).

To load and run a model using Flash Attention-2, refer to the snippet below:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

prompt = "My favourite condiment is"

model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
tokenizer.batch_decode(generated_ids)[0]
# "The expected output"

Expected speedups

Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the mistralai/Mixtral-8x7B-v0.1 checkpoint and the Flash Attention 2 version of the model.

Sliding window Attention

The current implementation supports the sliding window attention mechanism and memory-efficient cache management. To enable sliding window attention, just make sure to have a flash-attn version that is compatible with sliding window attention (>=2.3.0).

The Flash Attention-2 model also uses a more memory-efficient cache slicing mechanism. As recommended by the official implementation of the Mistral model, which uses a rolling cache mechanism, we keep the cache size fixed (self.config.sliding_window), support batched generation only for padding_side="left", and use the absolute position of the current token to compute the positional embedding.
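
As a rough mental model of the rolling cache (an illustrative sketch, not the transformers implementation): the key/value cache is a fixed-size buffer of sliding_window slots, new entries overwrite the slot given by the token's absolute position modulo the window size, and position embeddings are still computed from the absolute position.

import torch

# Illustrative only: fixed-size rolling key cache with sliding_window slots
sliding_window, num_kv_heads, head_dim = 4096, 8, 128
k_cache = torch.zeros(1, num_kv_heads, sliding_window, head_dim)

def write_key(k_cache, new_k, absolute_position):
    # Entries older than the window are overwritten; slots wrap modulo the window size
    slot = absolute_position % sliding_window
    k_cache[:, :, slot] = new_k  # new_k: (1, num_kv_heads, head_dim)
    return k_cache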

Shrinking down Mixtral using quantization

As the Mixtral model has 45 billion parameters, that would require about 90GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using quantization. If the model is quantized to 4 bits (or half a byte per parameter), a single A100 with 40GB of RAM is enough to fit the entire model, as in that case only about 27 GB of RAM is required.
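
As a quick sanity check of those numbers (approximate, ignoring activations and the parts of the model that typically stay in higher precision):

params = 45e9               # total parameter count (approximate)
print(params * 2 / 1e9)     # float16: 2 bytes per parameter  -> ~90 GB
print(params * 0.5 / 1e9)   # 4-bit:   0.5 bytes per parameter -> ~22.5 GB; with overhead this lands around the quoted ~27 GB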

Quantizing a model is as simple as passing a quantization_config when loading it. Below, we leverage the bitsandbytes quantization library (but refer to this page for alternative quantization methods):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    quantization_config=quantization_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")

messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"},
]

model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")

generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
tokenizer.batch_decode(generated_ids)[0]
# "The expected output"

This model was contributed by Younes Belkada and Arthur Zucker. The original code can be found here.

Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Mixtral. If you’re interested in submitting a resource to be included here, please feel free to open a Pull Request and we’ll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

MixtralConfig

class transformers.MixtralConfig

( vocab_size = 32000 hidden_size = 4096 intermediate_size = 14336 num_hidden_layers = 32 num_attention_heads = 32 num_key_value_heads = 8 hidden_act = 'silu' max_position_embeddings = 131072 initializer_range = 0.02 rms_norm_eps = 1e-05 use_cache = True pad_token_id = None bos_token_id = 1 eos_token_id = 2 tie_word_embeddings = False rope_theta = 1000000.0 sliding_window = None attention_dropout = 0.0 num_experts_per_tok = 2 num_local_experts = 8 output_router_logits = False router_aux_loss_coef = 0.001 router_jitter_noise = 0.0 **kwargs )

Parameters

This is the configuration class to store the configuration of a MixtralModel. It is used to instantiate a Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of Mixtral-8x7B-v0.1 or Mixtral-8x7B-Instruct-v0.1.

mistralai/Mixtral-8x7B-v0.1 mistralai/Mixtral-8x7B-Instruct-v0.1

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

from transformers import MixtralModel, MixtralConfig

# Initializing a Mixtral-style configuration
configuration = MixtralConfig()

# Initializing a model from the configuration (random weights)
model = MixtralModel(configuration)

# Accessing the model configuration
configuration = model.config
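
For quick local experiments you can also instantiate a much smaller, randomly initialized variant by overriding the configuration fields; the sizes below are arbitrary and do not correspond to any released checkpoint.

from transformers import MixtralConfig, MixtralForCausalLM

# Tiny Mixtral for smoke tests (arbitrary sizes, random weights)
tiny_config = MixtralConfig(
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_local_experts=8,
    num_experts_per_tok=2,
)
tiny_model = MixtralForCausalLM(tiny_config)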

MixtralModel

class transformers.MixtralModel

( config: MixtralConfig )

Parameters

The bare Mixtral Model outputting raw hidden-states without any specific head on top. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a MixtralDecoderLayer.

forward

( input_ids: LongTensor = None attention_mask: Optional = None position_ids: Optional = None past_key_values: Optional = None inputs_embeds: Optional = None use_cache: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None output_router_logits: Optional = None return_dict: Optional = None cache_position: Optional = None )

Parameters

The MixtralModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
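
As a hedged usage sketch (the checkpoint is large, so in practice you would combine this with the quantization or Flash Attention options shown above), the bare model returns hidden states rather than logits:

import torch
from transformers import AutoTokenizer, MixtralModel

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
model = MixtralModel.from_pretrained("mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("Hello, Mixtral!", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)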

MixtralForCausalLM

class transformers.MixtralForCausalLM

( config )

forward

( input_ids: LongTensor = None attention_mask: Optional = None position_ids: Optional = None past_key_values: Optional = None inputs_embeds: Optional = None labels: Optional = None use_cache: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None output_router_logits: Optional = None return_dict: Optional = None cache_position: Optional = None num_logits_to_keep: int = 0 **loss_kwargs ) → transformers.modeling_outputs.MoeCausalLMOutputWithPast or tuple(torch.FloatTensor)

Parameters

Returns

transformers.modeling_outputs.MoeCausalLMOutputWithPast or tuple(torch.FloatTensor)

A transformers.modeling_outputs.MoeCausalLMOutputWithPast or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (MixtralConfig) and inputs.

The MixtralForCausalLM forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, MixtralForCausalLM

model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")

generate_ids = model.generate(inputs.input_ids, max_length=30)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
# "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."

MixtralForSequenceClassification

class transformers.MixtralForSequenceClassification

( config )

Parameters

The Mixtral Model transformer with a sequence classification head on top (linear layer).

MixtralForSequenceClassification uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do.

Since it does classification on the last token, it needs to know the position of the last token. If a pad_token_id is defined in the configuration, it finds the last token that is not a padding token in each row. If no pad_token_id is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when inputs_embeds are passed instead of input_ids, it does the same (takes the last value in each row of the batch).

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: Optional = None attention_mask: Optional = None position_ids: Optional = None past_key_values: Union = None inputs_embeds: Optional = None labels: Optional = None use_cache: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None )

Parameters

The MixtralForSequenceClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
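
The released Mixtral checkpoints do not include a classification head, so the example below (an illustrative sketch, with num_labels=2 chosen arbitrarily) loads a randomly initialized head on top of the pretrained backbone and is meant as a starting point for fine-tuning:

from transformers import AutoTokenizer, MixtralForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
model = MixtralForSequenceClassification.from_pretrained("mistralai/Mixtral-8x7B-v0.1", num_labels=2)
model.config.pad_token_id = tokenizer.eos_token_id  # needed so the last non-padding token can be located

inputs = tokenizer("This film was a pleasant surprise.", return_tensors="pt")
logits = model(**inputs).logits
predicted_class_id = logits.argmax(-1).item()  # not meaningful until the head is fine-tuned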

MixtralForTokenClassification

class transformers.MixtralForTokenClassification

( config )

Parameters

The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: Optional = None attention_mask: Optional = None position_ids: Optional = None past_key_values: Optional = None inputs_embeds: Optional = None labels: Optional = None use_cache: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.modeling_outputs.TokenClassifierOutput or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.TokenClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (MixtralConfig) and inputs.

The MixtralForTokenClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

from transformers import AutoTokenizer, MixtralForTokenClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
model = MixtralForTokenClassification.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

inputs = tokenizer(
    "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
)

with torch.no_grad():
    logits = model(**inputs).logits

predicted_token_class_ids = logits.argmax(-1)

predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]

# Reusing the predictions as labels here only illustrates how the loss is computed
labels = predicted_token_class_ids
loss = model(**inputs, labels=labels).loss

MixtralForQuestionAnswering

class transformers.MixtralForQuestionAnswering

( config )

Parameters

The Mixtral Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute span start logits and span end logits).

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( input_ids: Optional = None attention_mask: Optional = None position_ids: Optional = None past_key_values: Union = None inputs_embeds: Optional = None start_positions: Optional = None end_positions: Optional = None output_attentions: Optional = None output_hidden_states: Optional = None return_dict: Optional = None **kwargs )

Parameters

The MixtralForQuestionAnswering forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
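
Likewise, the span-prediction head is randomly initialized, so the sketch below only illustrates the API; the extracted spans will not be meaningful until the head is fine-tuned on a QA dataset:

import torch
from transformers import AutoTokenizer, MixtralForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
model = MixtralForQuestionAnswering.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

question, context = "Who released Mixtral?", "Mixtral-8x7B is a sparse mixture of experts model released by Mistral AI."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

start = outputs.start_logits.argmax()
end = outputs.end_logits.argmax()
answer = tokenizer.decode(inputs.input_ids[0, start : end + 1])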
