chore(deps): bump transformers from 4.33.2 to 4.36.0 in /tools/perf by dependabot[bot] · Pull Request #2555 · pytorch/TensorRT (original) (raw)

v4.36: Mixtral, Llava/BakLlava, SeamlessM4T v2, AMD ROCm, F.sdpa wide-spread support

New model additions

Mixtral

Mixtral is the new open-source model from Mistral AI announced by the blogpost Mixtral of Experts. The model has been proven to have comparable capabilities to Chat-GPT according to the benchmark results shared on the release blogpost.

The architecture is a sparse Mixture of Experts with Top-2 routing strategy, similar as NllbMoe architecture in transformers. You can use it through AutoModelForCausalLM interface:

import torch from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B", torch_dtype=torch.float16, device_map="auto") tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-8x7B") prompt = "My favourite condiment is" model_inputs = tokenizer([prompt], return_tensors="pt").to(device) model.to(device) generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True) tokenizer.batch_decode(generated_ids)[0]

The model is compatible with existing optimisation tools such Flash Attention 2, bitsandbytes and PEFT library. The checkpoints are release under mistralai organisation on the Hugging Face Hub.

Llava / BakLlava

Llava is an open-source chatbot trained by fine-tuning LlamA/Vicuna on GPT-generated multimodal instruction-following data. It is an auto-regressive language model, based on the transformer architecture. In other words, it is an multi-modal version of LLMs fine-tuned for chat / instructions.

The Llava model was proposed in Improved Baselines with Visual Instruction Tuning by Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee.

The integration also includes BakLlava which is a Llava model trained with Mistral backbone.

The mode is compatible with "image-to-text" pipeline:

from transformers import pipeline from PIL import Image
import requests model_id = "llava-hf/llava-1.5-7b-hf"

... (truncated)