Quantization
🤗 Optimum provides an optimum.onnxruntime package that enables you to apply quantization on many models hosted on the Hugging Face Hub using the ONNX Runtime quantization tool.
The quantization process is abstracted via the ORTConfig and the ORTQuantizer classes. The former allows you to specify how quantization should be done, while the latter effectively handles quantization.
You can read the conceptual guide on quantization to learn about the main concepts that you will be using when performing quantization with the ORTQuantizer.
Quantizing a model with Optimum’s CLI
The Optimum ONNX Runtime quantization tool can be used through the Optimum command-line interface:
optimum-cli onnxruntime quantize --help
usage: optimum-cli onnxruntime quantize [-h] --onnx_model ONNX_MODEL -o OUTPUT [--per_channel]
                                        (--arm64 | --avx2 | --avx512 | --avx512_vnni | --tensorrt | -c CONFIG)

options:
  -h, --help            show this help message and exit
  --arm64               Quantization for the ARM64 architecture.
  --avx2                Quantization with AVX-2 instructions.
  --avx512              Quantization with AVX-512 instructions.
  --avx512_vnni         Quantization with AVX-512 and VNNI instructions.
  --tensorrt            Quantization for NVIDIA TensorRT optimizer.
  -c CONFIG, --config CONFIG
                        ORTConfig file to use to optimize the model.

Required arguments:
  --onnx_model ONNX_MODEL
                        Path to the repository where the ONNX models to quantize are located.
  -o OUTPUT, --output OUTPUT
                        Path to the directory where to store generated ONNX model.

Optional arguments:
  --per_channel         Compute the quantization parameters on a per-channel basis.
Quantizing an ONNX model can be done as follows:
optimum-cli onnxruntime quantize --onnx_model onnx_model_location/ --avx512 -o quantized_model/
This quantizes all the ONNX files in onnx_model_location with AVX-512 instructions.
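Once the command finishes, the quantized model can be loaded back like any other ONNX checkpoint. Below is a minimal sketch, assuming the input directory held a sequence classification model and that the quantized file keeps the default model_quantized.onnx name; adjust both to your setup:

from optimum.onnxruntime import ORTModelForSequenceClassification

# File name is an assumption: check what was actually written to quantized_model/
model = ORTModelForSequenceClassification.from_pretrained(
    "quantized_model/", file_name="model_quantized.onnx"
)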
Creating an ORTQuantizer
The ORTQuantizer class is used to quantize your ONNX model. The class can be initialized using the from_pretrained() method, which supports different checkpoint formats.
- Using an already initialized ORTModelForXXX class.
from optimum.onnxruntime import ORTQuantizer, ORTModelForSequenceClassification

ort_model = ORTModelForSequenceClassification.from_pretrained(
    "optimum/distilbert-base-uncased-finetuned-sst-2-english"
)
quantizer = ORTQuantizer.from_pretrained(ort_model)
- Using a local ONNX model from a directory.
from optimum.onnxruntime import ORTQuantizer
quantizer = ORTQuantizer.from_pretrained("path/to/model")
Apply Dynamic Quantization
The ORTQuantizer class can be used to dynamically quantize your ONNX model. Below you will find an easy end-to-end example on how to dynamically quantize distilbert-base-uncased-finetuned-sst-2-english.
from optimum.onnxruntime import ORTQuantizer, ORTModelForSequenceClassification
from optimum.onnxruntime.configuration import AutoQuantizationConfig

onnx_model = ORTModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", export=True)
quantizer = ORTQuantizer.from_pretrained(onnx_model)

# Define the dynamic quantization methodology
dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)

# Quantize the model and save the result to disk
model_quantized_path = quantizer.quantize(
    save_dir="path/to/output/model",
    quantization_config=dqconfig,
)
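The quantized model can then be loaded and used for inference, for example through a transformers pipeline. This is a minimal sketch, assuming quantize() wrote the default model_quantized.onnx file to the output directory:

from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification

# File name is an assumption: adjust it to the file actually produced by quantize()
quantized_model = ORTModelForSequenceClassification.from_pretrained(
    "path/to/output/model", file_name="model_quantized.onnx"
)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
classifier = pipeline("text-classification", model=quantized_model, tokenizer=tokenizer)
print(classifier("I love the new quantized model!"))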
Static Quantization example
The ORTQuantizer class can be used to statically quantize your ONNX model. Below you will find an easy end-to-end example on how to statically quantize distilbert-base-uncased-finetuned-sst-2-english.
from functools import partial

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTQuantizer, ORTModelForSequenceClassification
from optimum.onnxruntime.configuration import AutoQuantizationConfig, AutoCalibrationConfig

model_id = "distilbert-base-uncased-finetuned-sst-2-english"
onnx_model = ORTModelForSequenceClassification.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
quantizer = ORTQuantizer.from_pretrained(onnx_model)

# Define the static quantization methodology
qconfig = AutoQuantizationConfig.arm64(is_static=True, per_channel=False)

def preprocess_fn(ex, tokenizer):
    return tokenizer(ex["sentence"])

# Build the calibration dataset used to compute the activation ranges
calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=partial(preprocess_fn, tokenizer=tokenizer),
    num_samples=50,
    dataset_split="train",
)

# Create the calibration configuration (here: min-max range calibration)
calibration_config = AutoCalibrationConfig.minmax(calibration_dataset)

# Perform the calibration step to compute the activation quantization ranges
ranges = quantizer.fit(
    dataset=calibration_dataset,
    calibration_config=calibration_config,
    operators_to_quantize=qconfig.operators_to_quantize,
)

# Quantize the model and save the result to disk
model_quantized_path = quantizer.quantize(
    save_dir="path/to/output/model",
    calibration_tensors_range=ranges,
    quantization_config=qconfig,
)
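To make the output directory self-contained, you can also save the tokenizer next to the quantized model; loading then works as in the dynamic example. A short sketch, again assuming the default model_quantized.onnx file name:

# Save the tokenizer alongside the quantized model
tokenizer.save_pretrained("path/to/output/model")

quantized_model = ORTModelForSequenceClassification.from_pretrained(
    "path/to/output/model", file_name="model_quantized.onnx"
)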
Quantize Seq2Seq models
The ORTQuantizer class currently doesn’t support multi-file models, like ORTModelForSeq2SeqLM. If you want to quantize a Seq2Seq model, you have to quantize each of the model’s components individually.
Currently, only dynamic quantization is supported for Seq2Seq models.
- Load the seq2seq model as ORTModelForSeq2SeqLM.
from optimum.onnxruntime import ORTQuantizer, ORTModelForSeq2SeqLM
from optimum.onnxruntime.configuration import AutoQuantizationConfig

model_id = "optimum/t5-small"
onnx_model = ORTModelForSeq2SeqLM.from_pretrained(model_id)
model_dir = onnx_model.model_save_dir
- Define a quantizer for the encoder, the decoder, and the decoder with past key values
encoder_quantizer = ORTQuantizer.from_pretrained(model_dir, file_name="encoder_model.onnx")
decoder_quantizer = ORTQuantizer.from_pretrained(model_dir, file_name="decoder_model.onnx")
decoder_wp_quantizer = ORTQuantizer.from_pretrained(model_dir, file_name="decoder_with_past_model.onnx")
quantizers = [encoder_quantizer, decoder_quantizer, decoder_wp_quantizer]
- Quantize all models
dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
for q in quantizers:
    q.quantize(save_dir=".", quantization_config=dqconfig)
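The quantized components can then be loaded back into a single ORTModelForSeq2SeqLM. The snippet below is a sketch, assuming the default *_quantized.onnx suffixes and that your Optimum version accepts per-component file name arguments in from_pretrained; verify both against your installed release:

from optimum.onnxruntime import ORTModelForSeq2SeqLM

# File and argument names are assumptions: check the files written to "." and your Optimum version
quantized_model = ORTModelForSeq2SeqLM.from_pretrained(
    ".",
    encoder_file_name="encoder_model_quantized.onnx",
    decoder_file_name="decoder_model_quantized.onnx",
    decoder_with_past_file_name="decoder_with_past_model_quantized.onnx",
)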