GitHub - intel/auto-round: Advanced Quantization Algorithm for LLMs/VLMs.
What's New
- [2025/04] AutoRound provides recipes for the Qwen3 series; please refer to Qwen3-8B-sym-recipe and Qwen3-14B-sym-recipe for more details.
- [2025/04] AutoRound has been integrated into Transformers. You can run models in the AutoRound format directly with Transformers versions later than 4.51.3.
- [2025/03] The INT2-mixed R1 model (~200GB) retains 97.9% accuracy. Check out OPEA/DeepSeek-R1-int2-mixed-sym-inc.
- [2025/01] We provide experimental support for GGUF q4_0 and q4_1 formats.
- [2024/11] We provide experimental support for VLM quantization; please check out the README for details.
Installation
Install from PyPI
GPU
pip install auto-round
CPU
pip install auto-round[cpu]
HPU
pip install auto-round-lib
Build from Source
GPU
pip install .
CPU
pip install .[cpu]
HPU
python setup.py install lib
Model Quantization
Command Line Usage (Gaudi/CPU/Intel GPU/CUDA)
A user guide with the full list of supported arguments is available by running auto-round -h in the terminal. Set the desired export format with --format; exporting to multiple formats at once is supported. Please check out the step-by-step instructions for more details about the calibration dataset and evaluation.
auto-round \
    --model Qwen/Qwen3-0.6B \
    --bits 4 \
    --group_size 128 \
    --format "auto_gptq,auto_awq,auto_round" \
    --output_dir ./tmp_autoround
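For scripted runs, the same command can be assembled from Python. This is a minimal sketch, assuming the auto-round entry point (or auto-round-best / auto-round-light) is installed and on PATH. The helpers build_autoround_cmd and run_autoround are hypothetical, and they use only the flags shown above.

```python
import shlex
import subprocess


def build_autoround_cmd(model, bits=4, group_size=128, formats=("auto_round",),
                        output_dir="./tmp_autoround", entry_point="auto-round"):
    """Assemble an argv list for the auto-round CLI (hypothetical helper)."""
    return [
        entry_point,
        "--model", model,
        "--bits", str(bits),
        "--group_size", str(group_size),
        # several export formats are passed as one comma-separated value
        "--format", ",".join(formats),
        "--output_dir", output_dir,
    ]


def run_autoround(cmd):
    """Launch the job; this requires auto-round to be installed."""
    subprocess.run(cmd, check=True)


cmd = build_autoround_cmd("Qwen/Qwen3-0.6B", formats=("auto_gptq", "auto_awq", "auto_round"))
print(shlex.join(cmd))
```

To use one of the recipes described below, pass entry_point="auto-round-best" or entry_point="auto-round-light". Keeping the arguments as a list means run_autoround never goes through a shell, so no quoting is required.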
We offer two configurations, auto-round-best and auto-round-light, designed for optimal accuracy and improved speed, respectively. Details are as follows.
Other Recipes
Best accuracy, about 3X slower; low_gpu_mem_usage can save ~20GB of GPU memory but is ~30% slower.
auto-round-best \
    --model Qwen/Qwen3-0.6B \
    --bits 4 \
    --group_size 128 \
    --low_gpu_mem_usage
Light: 2-3X speedup, with a slight accuracy drop at W4 and a larger accuracy drop at W2.
auto-round-light \
    --model Qwen/Qwen3-0.6B \
    --bits 4 \
    --group_size 128
In conclusion, we recommend using auto-round for INT4 and auto-round-best for INT2. However, you may adjust the configuration to suit your specific requirements and available resources.
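The recipes differ mainly in how much calibration data and how many tuning iterations they use. The sketch below collects values stated in this README into keyword arguments for AutoRound. The default values (iters=200, nsamples=128) come from the hyperparameter list, and the best and light values come from the API example. The RECIPES dict and recipe_kwargs helper are hypothetical, and the CLI presets may set additional options.

```python
# Tuning settings taken from this README; the names "default", "best" and "light" are illustrative.
RECIPES = {
    "default": {"nsamples": 128, "iters": 200},
    "best": {"nsamples": 512, "iters": 1000, "low_gpu_mem_usage": True},
    "light": {"nsamples": 128, "iters": 50, "lr": 5e-3},
}


def recipe_kwargs(name, bits=4, group_size=128, sym=True):
    """Merge a recipe's tuning settings with the quantization scheme (hypothetical helper)."""
    if name not in RECIPES:
        raise ValueError(f"unknown recipe {name!r}; choose from {sorted(RECIPES)}")
    return {"bits": bits, "group_size": group_size, "sym": sym, **RECIPES[name]}


kwargs = recipe_kwargs("light")
print(kwargs)
# With auto-round installed, this could be passed on as AutoRound(model, tokenizer, **kwargs).
```

Following the recommendation above, an INT2 run would use recipe_kwargs("best", bits=2, group_size=64).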
W4G128 average accuracy over 13 tasks and time cost (tested on an NVIDIA A100 80GB with PyTorch 2.6.0 and enable_torch_compile):
Model | Qwen2.5-0.5B-Instruct | Falcon3-3B | Qwen2.5-7B-Instruct | Meta-Llama-3.1-8B-Instruct | Falcon3-10B | Qwen2.5-72B-Instruct |
---|---|---|---|---|---|---|
16bits | 0.4192 | 0.5203 | 0.6470 | 0.6212 | 0.6151 | 0.7229 |
Best | 0.4137(7m) | 0.5142(23m) | 0.6426(58m) | 0.6116(65m) | 0.6092(81m) | 0.7242(575m) |
Default | 0.4129(2m) | 0.5133(6m) | 0.6441(13m) | 0.6106(13m) | 0.6080(18m) | 0.7252(118m) |
Light | 0.4052(2m) | 0.5108(3m) | 0.6453(5m) | 0.6104(6m) | 0.6063(6m) | 0.7243(37m) |
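To compare recipes at a glance, divide each quantized accuracy by its 16-bit baseline. The sketch below does this for three columns of the W4G128 table above. The numbers are copied from the table, and nothing else is assumed.

```python
# Average accuracy copied from the W4G128 table above.
W4G128 = {
    "Qwen2.5-7B-Instruct": {"16bits": 0.6470, "Best": 0.6426, "Default": 0.6441, "Light": 0.6453},
    "Meta-Llama-3.1-8B-Instruct": {"16bits": 0.6212, "Best": 0.6116, "Default": 0.6106, "Light": 0.6104},
    "Qwen2.5-72B-Instruct": {"16bits": 0.7229, "Best": 0.7242, "Default": 0.7252, "Light": 0.7243},
}


def retention(row):
    """Accuracy of each recipe as a percentage of the 16-bit baseline."""
    base = row["16bits"]
    return {recipe: round(100.0 * acc / base, 2) for recipe, acc in row.items() if recipe != "16bits"}


for model, row in W4G128.items():
    print(model, retention(row))
```

For example, the Default recipe retains about 99.55% of the 16-bit accuracy on Qwen2.5-7B-Instruct. On Qwen2.5-72B-Instruct, all three recipes score slightly above the 16-bit baseline in this table.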
W2G64 average accuracy over 13 tasks and time cost (tested on an NVIDIA A100 80GB with PyTorch 2.6.0 and enable_torch_compile). We recommend using higher precision for the head, tail, and non-expert modules to alleviate the significant accuracy drop.
Model | Qwen2.5-0.5B-Instruct | Falcon3-3B | Qwen2.5-7B-Instruct | Falcon3-10B | Qwen2.5-72B-Instruct |
---|---|---|---|---|---|
16bits | 0.4192 | 0.5203 | 0.6470 | 0.6151 | 0.7229 |
Best | 0.2989(6m) | 0.4267(24m) | 0.5343(56m) | 0.5207(79m) | 0.6715(564m) |
Default | 0.2878(2m) | 0.4219(6m) | 0.5209(13m) | 0.5133(18m) | 0.6713(122m) |
Light | 0.2760(2m) | 0.4063(3m) | 0.4764(5m) | 0.4810(7m) | 0.6581(38m) |
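The recommendation above to keep the head and tail in higher precision could be expressed through the layer_config argument described under Detailed Hyperparameters. The helper below is hypothetical and rests on three assumptions. First, layer_config maps full module names to per-layer overrides such as {"bits": 8}. Second, "head" and "tail" mean the first and last transformer blocks. Third, module names follow the Llama/Qwen-style pattern model.layers.{i}.<module>. Non-expert modules are not handled. Please check the AutoRound documentation for the exact schema.

```python
def head_tail_layer_config(num_layers, modules, n_edge=1, high_bits=8, group_size=128):
    """Keep the first and last n_edge blocks at high_bits (hypothetical helper).

    Assumes module names of the form model.layers.{i}.{module}.
    """
    edge_blocks = set(range(n_edge)) | set(range(num_layers - n_edge, num_layers))
    config = {}
    for i in sorted(edge_blocks):
        for m in modules:
            config[f"model.layers.{i}.{m}"] = {"bits": high_bits, "group_size": group_size}
    return config


modules = ["self_attn.q_proj", "mlp.down_proj"]
layer_config = head_tail_layer_config(num_layers=28, modules=modules)
print(sorted(layer_config))
# Assumed usage: AutoRound(model, tokenizer, bits=2, group_size=64, layer_config=layer_config)
```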
API Usage (HPU/CPU/XPU/CUDA)
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "Qwen/Qwen3-0.6B"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

bits, group_size, sym = 4, 128, True
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym)

# Best accuracy, 4-5X slower; low_gpu_mem_usage can save ~20GB of GPU memory but is ~30% slower
# autoround = AutoRound(model, tokenizer, nsamples=512, iters=1000, low_gpu_mem_usage=True, bits=bits, group_size=group_size, sym=sym)

# 2-3X speedup, slight accuracy drop at W4G128
# autoround = AutoRound(model, tokenizer, nsamples=128, iters=50, lr=5e-3, bits=bits, group_size=group_size, sym=sym)

output_dir = "./tmp_autoround"
# format: 'auto_round' (default), 'auto_gptq', 'auto_awq'
autoround.quantize_and_save(output_dir, format='auto_round')
Detailed Hyperparameters
- model: The PyTorch model to be quantized.
- tokenizer: An optional tokenizer for processing input data. If None, a dataset must be provided.
- bits (int): Number of bits for quantization (default is 4).
- group_size (int): Size of the quantization group (default is 128).
- sym (bool): Whether to use symmetric quantization (default is True).
- enable_quanted_input (bool): Whether to use the output of the previous quantized block as the input for the current block during tuning (default is True).
- enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True).
- iters (int): Number of tuning iterations (default is 200).
- lr (float): The learning rate for the rounding values (default is None, in which case it is set to 1.0/iters automatically).
- minmax_lr (float): The learning rate for min-max tuning (default is None, in which case it is set to lr automatically).
- nsamples (int): Number of samples for tuning (default is 128).
- seqlen (int): Sequence length of each tuning sample (default is 2048).
- batch_size (int): Batch size for tuning (default is 8).
- scale_dtype (str): The data type of the quantization scale (default is "float16"); different kernels support different choices.
- amp (bool): Whether to use automatic mixed precision (default is True).
- nblocks (int): Number of blocks packed together and tuned as one (default is 1).
- gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
- low_gpu_mem_usage (bool): Whether to save GPU memory at the cost of ~20% more tuning time (default is False).
- dataset (Union[str, list, tuple, torch.utils.data.DataLoader]): The dataset name for tuning (default is "NeelNanda/pile-10k"). Local JSON files and combinations of datasets are supported, e.g. "./tmp.json,NeelNanda/pile-10k:train,mbpp:train+validation+test".
- layer_config (dict): Configuration for weight quantization (default is None), mainly for mixed bits or mixed precision.
- device: The device to be used for tuning. The default is 'auto', which enables automatic detection.
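To make the iters and lr entries concrete, here is a toy illustration of the idea in the paper title: optimizing weight rounding with signed gradient descent. Each weight gets a rounding offset v in [-0.5, 0.5]. The offsets are tuned with sign-of-gradient steps of size lr, which defaults to 1.0/iters as documented above, to reduce the output error of one 4-bit symmetric linear neuron. This is a pure-Python sketch under simplifying assumptions (a single group, a fixed scale, a straight-through estimator for round, no min-max tuning or block-wise inputs). It is not AutoRound's implementation, and toy_signround is a hypothetical name.

```python
import random


def fake_quant(w, v, scale, qmin=-8, qmax=7):
    """4-bit symmetric fake quantization of w with rounding offsets v."""
    return [max(qmin, min(qmax, round(wi / scale + vi))) * scale for wi, vi in zip(w, v)]


def loss_and_grad(w, v, scale, xs):
    """Squared output error of one linear neuron and its gradient w.r.t. v.

    round() is treated as identity in the backward pass (straight-through
    estimator) and clipping is ignored, so dq_i/dv_i is approximated by scale.
    """
    q = fake_quant(w, v, scale)
    loss, grad = 0.0, [0.0] * len(w)
    for x in xs:
        err = sum(xi * (qi - wi) for xi, qi, wi in zip(x, q, w))
        loss += err * err
        for i, xi in enumerate(x):
            grad[i] += 2.0 * err * xi * scale
    return loss, grad


def toy_signround(w, xs, iters=200, lr=None):
    """Tune rounding offsets with signed gradient descent and return the best ones."""
    if lr is None:
        lr = 1.0 / iters if iters > 0 else 0.0  # mirrors the documented lr default
    scale = max(abs(wi) for wi in w) / 7
    v = [0.0] * len(w)  # v = 0 is plain round-to-nearest (RTN)
    best_v, best_loss = list(v), loss_and_grad(w, v, scale, xs)[0]
    for _ in range(iters):
        loss, grad = loss_and_grad(w, v, scale, xs)
        if loss < best_loss:
            best_v, best_loss = list(v), loss
        # step against the sign of the gradient, keeping each offset in [-0.5, 0.5]
        v = [max(-0.5, min(0.5, vi - lr * ((g > 0) - (g < 0)))) for vi, g in zip(v, grad)]
    return best_v, best_loss


random.seed(0)
w = [random.gauss(0.0, 1.0) for _ in range(16)]
xs = [[random.gauss(0.0, 1.0) for _ in range(16)] for _ in range(32)]
rtn_loss = loss_and_grad(w, [0.0] * 16, max(abs(wi) for wi in w) / 7, xs)[0]
v, tuned_loss = toy_signround(w, xs, iters=200)
print(f"RTN loss {rtn_loss:.4f} -> tuned loss {tuned_loss:.4f}")
```

Because the sketch keeps the best offsets seen, its result is never worse than RTN on the calibration samples. With iters=0 it returns plain RTN.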
API Usage for VLMs
If you encounter issues during quantization, try setting iters=0 (to enable RTN) and use group_size=32 for better results.
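For reference, iters=0 amounts to plain round-to-nearest (RTN) quantization. A smaller group_size gives each group its own, typically smaller, scale, which tends to track the weights more closely. The pure-Python sketch below shows symmetric group-wise RTN as a general technique. The helper names are hypothetical, and this is not AutoRound's kernel code.

```python
import random


def rtn_groupwise(weights, bits=4, group_size=128):
    """Symmetric round-to-nearest (RTN) quantization with one scale per group."""
    qmax = 2 ** (bits - 1) - 1  # 7 for 4-bit
    out = []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        scale = max(abs(x) for x in group) / qmax or 1.0  # guard all-zero groups
        out.extend(max(-qmax - 1, min(qmax, round(x / scale))) * scale for x in group)
    return out


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


random.seed(0)
w = [random.gauss(0.0, 1.0) for _ in range(512)]
for gs in (128, 32):
    print(f"group_size={gs}: mse={mse(w, rtn_groupwise(w, bits=4, group_size=gs)):.6f}")
```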
This feature is experimental and may be subject to changes, including potential bug fixes, API modifications, or adjustments to default hyperparameters.
By default, AutoRoundMLLM quantizes only the text module of VLMs and uses NeelNanda/pile-10k for calibration. To quantize the entire model, set quant_nontext_module to True, though support for this feature is limited. For more information, please refer to the AutoRoundMLLM readme.
from auto_round import AutoRoundMLLM
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoTokenizer
# load the model
model_name = "Qwen/Qwen2-VL-2B-Instruct"
model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, trust_remote_code=True, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
# quantize the model
bits, group_size, sym = 4, 128, True
autoround = AutoRoundMLLM(model, tokenizer, processor, bits=bits, group_size=group_size, sym=sym)
autoround.quantize()
# save the quantized model; set format='auto_gptq' or 'auto_awq' to use other formats
output_dir = "./tmp_autoround"
autoround.save_quantized(output_dir, format='auto_round', inplace=True)
Export Formats
AutoRound Format: This format is well-suited for CPU and HPU devices, 2-bit models, and mixed-precision inference. 2, 3, 4, and 8 bits are supported. However, it has not yet gained widespread community adoption.
AutoGPTQ Format: This format is well-suited for symmetric quantization on CUDA devices and is widely adopted by the community; 2, 3, 4, and 8 bits are supported. However, the asymmetric kernel has issues that can cause considerable accuracy drops, particularly with 2-bit quantization and small models. In addition, 3-bit models may currently have some accuracy issues in Transformers.
AutoAWQ Format: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely adopted within the community; only 4-bit quantization is supported.
GGUF Format: This format is well-suited for CPU devices and is widely adopted by the community; only q4_0 and q4_1 (W4G32) are supported in our repo.
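The format constraints above can be encoded as a quick pre-flight check before exporting. This minimal sketch reflects only what this README states: bit widths per format, GGUF limited to q4_0/q4_1 at W4G32, and the accuracy caveats for AutoGPTQ. The FORMAT_BITS table and check_export helper are hypothetical and not part of auto-round. The "gguf:q4_0" / "gguf:q4_1" spellings are an assumption about how those formats are named.

```python
FORMAT_BITS = {
    "auto_round": {2, 3, 4, 8},
    "auto_gptq": {2, 3, 4, 8},
    "auto_awq": {4},
}
GGUF_FORMATS = {"gguf:q4_0", "gguf:q4_1"}  # assumed format names


def check_export(fmt, bits, group_size=128, sym=True):
    """Return warnings for a format/scheme pair; raise ValueError if it is unsupported."""
    if fmt in GGUF_FORMATS:
        if (bits, group_size) != (4, 32):
            raise ValueError(f"{fmt} only supports W4G32")
        return []
    if fmt not in FORMAT_BITS:
        raise ValueError(f"unknown format {fmt!r}")
    if bits not in FORMAT_BITS[fmt]:
        raise ValueError(f"{fmt} does not support {bits}-bit weights")
    warnings = []
    if fmt == "auto_gptq" and not sym:
        warnings.append("the asymmetric AutoGPTQ kernel may cause considerable accuracy drops")
    if fmt == "auto_gptq" and bits == 3:
        warnings.append("3-bit AutoGPTQ models may have accuracy issues in Transformers")
    return warnings


print(check_export("auto_gptq", 2, sym=False))
```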
Quantization Costs
Testing was conducted on an NVIDIA A100 80GB using the nightly version of PyTorch 2.6.0.dev20241029+cu124. Please note that data loading and packing costs have been excluded from the evaluation. We recommend enabling torch.compile for PyTorch versions 2.6 and above.
To further reduce GPU memory usage, in addition to enabling low_gpu_mem_usage, you can set gradient_accumulate_steps=8 and batch_size=1, though this may increase tuning time.
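Setting batch_size=1 with gradient_accumulate_steps=8 keeps the same effective batch of 8 samples as the default batch_size=8. The savings therefore come from processing one sample at a time, not from changing what is optimized. The toy sketch below uses plain Python and a one-parameter squared-error loss (it is not AutoRound code). It checks that averaging gradients over eight micro-batches of one sample reproduces the full-batch gradient.

```python
def grad_mean_sq_error(w, batch):
    """Gradient of mean((w*x - y)^2) over a batch w.r.t. the scalar w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(float(i), 2.0 * i + 1.0) for i in range(8)]
w = 0.5

full = grad_mean_sq_error(w, data)  # batch_size=8, no accumulation

steps = 8  # gradient_accumulate_steps=8 with batch_size=1
accumulated = 0.0
for k in range(steps):
    micro = data[k:k + 1]
    # scale each micro-batch gradient by 1/steps so the sum equals the batch mean
    accumulated += grad_mean_sq_error(w, micro) / steps

print(full, accumulated)
```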
The 3B and 14B models are Qwen2.5 models, the 8X7B model is Mixtral, and the remaining models are LLaMA 3.1.
Torch version/Config (W4G128) | 3B | 8B | 14B | 70B | 8X7B |
---|---|---|---|---|---|
2.6 with torch compile | 7min / 10GB | 12min / 18GB | 23min / 22GB | 120min / 42GB | 28min / 46GB |
2.6 with torch compile, low_gpu_mem_usage=True | 12min / 6GB | 19min / 10GB | 33min / 11GB | 140min / 25GB | 38min / 36GB |
2.6 with torch compile, low_gpu_mem_usage=True, gradient_accumulate_steps=8, bs=1 | 15min / 3GB | 25min / 6GB | 45min / 7GB | 187min / 19GB | 75min / 36GB |
2.5 w/o torch compile | 8min / 10GB | 16min / 20GB | 30min / 25GB | 140min / 49GB | 50min / 49GB |
Model Inference
Please run the quantization code first.
AutoRound format
CPU: pip install intel-extension-for-pytorch (much higher speed on Intel CPUs) or pip install intel-extension-for-transformers.
HPU: a docker image with the Gaudi Software Stack is recommended. More details can be found in the Gaudi Guide.
CUDA: no extra steps are needed for symmetric quantization; for asymmetric quantization, install auto-round from source.
HPU/CPU/XPU/CUDA
Please avoid manually moving the quantized model to a different device (e.g., model.to('cpu')) during inference, as this may cause unexpected exceptions.
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRoundConfig  # must import for auto-round format
quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
Specify backend
AutoRound automatically selects the best available backend based on the installed libraries and prompts the user to install additional libraries when a better backend is found. On CUDA, the default priority is Marlin > ExLLaMAV2 > Triton, but the final choice depends on factors such as bits, group_size, and packing format compatibility. However, the automatically selected backend may not always be the most suitable for a given device, so please refer to the following table for details and specify the backend you want.
Name | Devices | Bits | Dtypes | Priority | Packing format | Requirements |
---|---|---|---|---|---|---|
ipex | cpu/xpu | 4 | BF16/FP16 | 5 | gptq_zp+-1/awq | intel-extension-for-pytorch |
itrex | cpu | 2,4,8 | BF16/FP16 | 0 | gptq_zp+-1/awq | intel-extension-for-transformers |
marlin | cuda | 4,8 | BF16/FP16 | 6 | gptq/gptq_zp+-1 | gptqmodel |
exllamav2 or gptqmodel:exllamav2 | cuda | 4 | BF16/FP16 | 5 | gptq | gptqmodel |
exllamav2 or gptq:exllamav2 | cuda | 4 | FP16 | 5 | gptq_zp+-1 | auto-gptq |
gptq:cuda | cuda | 2,3,4,8 | FP16 | 0 | gptq_zp+-1 | auto-gptq |
triton | cuda | 2,4,8 | BF16/FP16 | 1 | gptq/gptq_zp+-1 | auto-round |
awq | cuda | 4 | FP16 | 5 | awq | auto-awq |
hpu | hpu | 4 | BF16 | 0 | gptq/gptq_zp+-1 | auto-round |
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRoundConfig
quantized_model_path = "./tmp_autoround"
# "auto" lets AutoRound choose; a backend name from the table above can be given instead
quantization_config = AutoRoundConfig(backend="auto")
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto", torch_dtype="auto", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
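The selection rule described above (filter backends by device, bit width, and packing format, then prefer the highest priority) can be illustrated with the table's values. The sketch below is a simplified, hypothetical model of that logic. It ignores dtypes and installed-library checks, and the real selection in auto-round considers more factors. The exllamav2 rows use their qualified names from the table.

```python
# (name, devices, bits, priority, packing formats) copied from the backend table above
BACKENDS = [
    ("ipex", {"cpu", "xpu"}, {4}, 5, {"gptq_zp+-1", "awq"}),
    ("itrex", {"cpu"}, {2, 4, 8}, 0, {"gptq_zp+-1", "awq"}),
    ("marlin", {"cuda"}, {4, 8}, 6, {"gptq", "gptq_zp+-1"}),
    ("gptqmodel:exllamav2", {"cuda"}, {4}, 5, {"gptq"}),
    ("gptq:exllamav2", {"cuda"}, {4}, 5, {"gptq_zp+-1"}),
    ("gptq:cuda", {"cuda"}, {2, 3, 4, 8}, 0, {"gptq_zp+-1"}),
    ("triton", {"cuda"}, {2, 4, 8}, 1, {"gptq", "gptq_zp+-1"}),
    ("awq", {"cuda"}, {4}, 5, {"awq"}),
    ("hpu", {"hpu"}, {4}, 0, {"gptq", "gptq_zp+-1"}),
]


def pick_backend(device, bits, packing):
    """Return the highest-priority backend that matches device, bits and packing format."""
    candidates = [b for b in BACKENDS if device in b[1] and bits in b[2] and packing in b[4]]
    if not candidates:
        raise ValueError(f"no backend for device={device}, bits={bits}, packing={packing}")
    return max(candidates, key=lambda b: b[3])[0]


print(pick_backend("cuda", 4, "gptq"))
print(pick_backend("cuda", 2, "gptq_zp+-1"))
```

The chosen name could then be passed as AutoRoundConfig(backend=...) in the example above, assuming that the table names are accepted there.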
Convert GPTQ/AWQ format to AutoRound
Most GPTQ/AWQ models can be converted to the AutoRound format for better compatibility and support with Intel devices. Please note that the quantization config will be changed if the model is serialized.
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRoundConfig  # must import for auto-round format
model_name = "ybelkada/opt-125m-gptq-4bit"
quantization_config = AutoRoundConfig()
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", torch_dtype="auto", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50, do_sample=False)[0]))
Evaluation
auto-round \
    --model saved_quantized_model \
    --eval \
    --task lambada_openai \
    --eval_bs 1
Support List
AutoRound supports nearly all major large language models.
Supported Models List
Please note that an asterisk (*) indicates third-party quantized models, which may lack accuracy data and use a different recipe. We greatly appreciate their efforts and encourage more users to share their models, as we cannot release most of the models ourselves.
VLM Support Matrix
For most VLMs, we support the default quantization configuration, which quantizes only the language component and excludes the visual component. We also support quantizing the non-text modules of models that follow the Hugging Face standard (i.e., those with a typical processor), though inference may have issues due to model architecture or kernel limitations.
√ means supported, - means export is supported but inference is not, and X means not supported.
Integration
AutoRound has been integrated into multiple repositories.
Reference
If you find AutoRound useful for your research, please cite our paper:
@article{cheng2023optimize,
  title={Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs},
  author={Cheng, Wenhua and Zhang, Weiwei and Shen, Haihao and Cai, Yiyang and He, Xin and Lv, Kaokao and Liu, Yi},
  journal={arXiv preprint arXiv:2309.05516},
  year={2023}
}