openvinotoolkit/nncf: Neural Network Compression Framework for enhanced OpenVINO™ inference
Neural Network Compression Framework (NNCF) provides a suite of post-training and training-time algorithms for optimizing inference of neural networks in OpenVINO™ with a minimal accuracy drop.
NNCF is designed to work with models from PyTorch, TorchFX, ONNX, and OpenVINO™.
NNCF provides samples that demonstrate the usage of compression algorithms for different use cases and models. See compression results achievable with the NNCF-powered samples on the NNCF Model Zoo page.
The framework is organized as a Python* package that can be built and used in a standalone mode. The framework architecture is unified to make it easy to add different compression algorithms for each of the supported deep learning frameworks.
Key Features
Post-Training Compression Algorithms
| Compression algorithm | OpenVINO | PyTorch | TorchFX | ONNX |
|---|---|---|---|---|
| Post-Training Quantization | Supported | Supported | Experimental | Supported |
| Weights Compression | Supported | Supported | Experimental | Supported |
| Activation Sparsity | Not supported | Experimental | Not supported | Not supported |
Training-Time Compression Algorithms
| Compression algorithm | PyTorch |
|---|---|
| Quantization Aware Training | Supported |
| Weight-Only Quantization Aware Training with LoRA and NLS | Supported |
| Mixed-Precision Quantization | Supported |
- Automatic, configurable model graph transformation to obtain the compressed model.
- Common interface for compression methods.
- GPU-accelerated layers for faster compressed model fine-tuning.
- Distributed training support.
- Git patch for a prominent third-party repository (huggingface-transformers) demonstrating the process of integrating NNCF into custom training pipelines.
- Exporting PyTorch compressed models to ONNX* checkpoints, ready to use with the OpenVINO™ toolkit.
Documentation
This documentation covers detailed information about NNCF algorithms and functions needed to contribute to NNCF.
The latest user documentation for NNCF is available here.
NNCF API documentation can be found here.
Usage
Post-Training Quantization
NNCF Post-Training Quantization (PTQ) is the simplest way to apply 8-bit quantization. To run the algorithm, you only need your model and a small (~300 samples) calibration dataset.
OpenVINO is the preferred backend to run PTQ with, while PyTorch and ONNX are also supported.
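The calibration dataset is used to collect statistics, such as value ranges, from which quantization parameters are derived. As a rough illustration of that idea only, here is a plain-Python sketch of per-tensor asymmetric 8-bit quantization with a simple min/max range estimate. The function names and the min/max scheme are assumptions for illustration; NNCF's PTQ uses its own statistic collectors and range-setting logic.

```python
def calibrate_range(batches):
    """Track the global min/max over all calibration values (a simple min/max observer)."""
    lo, hi = float("inf"), float("-inf")
    for batch in batches:
        lo = min(lo, min(batch))
        hi = max(hi, max(batch))
    # Extend the range so that 0.0 is exactly representable.
    return min(lo, 0.0), max(hi, 0.0)


def compute_qparams(lo, hi, qmin=0, qmax=255):
    """Derive an asymmetric uint8 scale and integer zero point from a float range."""
    scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
    zero_point = int(round(qmin - lo / scale))
    return scale, max(qmin, min(qmax, zero_point))


def quantize_value(x, scale, zero_point, qmin=0, qmax=255):
    """Map a float to its clamped 8-bit integer code."""
    return max(qmin, min(qmax, int(round(x / scale)) + zero_point))


def dequantize_value(q, scale, zero_point):
    """Map an 8-bit integer code back to an approximate float."""
    return (q - zero_point) * scale


# Hypothetical activation values standing in for statistics gathered on calibration samples.
calibration = [[-1.0, 0.5, 2.0], [0.0, 1.5, 3.0], [-0.5, 2.5, 1.0]]
lo, hi = calibrate_range(calibration)
scale, zp = compute_qparams(lo, hi)
x = 1.23
x_hat = dequantize_value(quantize_value(x, scale, zp), scale, zp)
print(f"range=({lo}, {hi}) scale={scale:.5f} zero_point={zp} x={x} -> {x_hat:.4f}")
```

Values inside the calibrated range are reconstructed with an error of at most half a quantization step, which is why a representative calibration set matters.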
OpenVINO
```python
import nncf
import openvino as ov
import torch
from torchvision import datasets, transforms

# Instantiate your uncompressed model
model = ov.Core().read_model("/model_path")

# Provide validation part of the dataset to collect statistics needed for the compression algorithm
val_dataset = datasets.ImageFolder("/path", transform=transforms.Compose([transforms.ToTensor()]))
dataset_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1)

# Step 1: Initialize transformation function
def transform_fn(data_item):
    images, _ = data_item
    return images

# Step 2: Initialize NNCF Dataset
calibration_dataset = nncf.Dataset(dataset_loader, transform_fn)

# Step 3: Run the quantization pipeline
quantized_model = nncf.quantize(model, calibration_dataset)
```
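The transformation function only has to turn one item yielded by the data source into the input the model expects. The sketch below mimics that contract in plain Python with a hypothetical SimpleCalibrationDataset class and fake data; it is not the nncf.Dataset implementation, just an illustration of how a data source, transform_fn, and a ~300-sample subset fit together.

```python
from itertools import islice


class SimpleCalibrationDataset:
    """Hypothetical stand-in: wraps any iterable and applies transform_fn lazily."""

    def __init__(self, data_source, transform_fn=None):
        self._data_source = data_source
        self._transform_fn = transform_fn or (lambda item: item)

    def iter_model_inputs(self, subset_size=None):
        """Yield transformed items, optionally stopping after subset_size items."""
        items = iter(self._data_source)
        if subset_size is not None:
            items = islice(items, subset_size)
        for item in items:
            yield self._transform_fn(item)


def transform_fn(data_item):
    # A DataLoader over ImageFolder yields (images, labels); the model only needs images.
    images, _ = data_item
    return images


# Fake (image, label) pairs standing in for a torch DataLoader.
fake_loader = [([0.1 * i] * 4, i % 10) for i in range(1000)]
dataset = SimpleCalibrationDataset(fake_loader, transform_fn)
model_inputs = list(dataset.iter_model_inputs(subset_size=300))
print(len(model_inputs), model_inputs[1])
```

Keeping the transformation separate from the data source lets the same DataLoader be reused for validation, while the compression step sees only model inputs.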
PyTorch
```python
import nncf
import torch
from torchvision import datasets, models, transforms

# Instantiate your uncompressed model
model = models.mobilenet_v2()

# Provide validation part of the dataset to collect statistics needed for the compression algorithm
val_dataset = datasets.ImageFolder("/path", transform=transforms.Compose([transforms.ToTensor()]))
dataset_loader = torch.utils.data.DataLoader(val_dataset)

# Step 1: Initialize the transformation function
def transform_fn(data_item):
    images, _ = data_item
    return images

# Step 2: Initialize NNCF Dataset
calibration_dataset = nncf.Dataset(dataset_loader, transform_fn)

# Step 3: Run the quantization pipeline
quantized_model = nncf.quantize(model, calibration_dataset)
```
NOTE: If the Post-Training Quantization algorithm does not meet quality requirements, you can fine-tune the quantized PyTorch model. You can find an example of the Quantization-Aware Training pipeline for a PyTorch model here.
TorchFX
```python
import nncf
import torch
import torch.fx
from torchvision import datasets, models, transforms

# Instantiate your uncompressed model
model = models.mobilenet_v2()

# Provide validation part of the dataset to collect statistics needed for the compression algorithm
val_dataset = datasets.ImageFolder("/path", transform=transforms.Compose([transforms.ToTensor()]))
dataset_loader = torch.utils.data.DataLoader(val_dataset)

# Step 1: Initialize the transformation function
def transform_fn(data_item):
    images, _ = data_item
    return images

# Step 2: Initialize NNCF Dataset
calibration_dataset = nncf.Dataset(dataset_loader, transform_fn)

# Step 3: Export model to TorchFX
input_shape = (1, 3, 224, 224)
ex_input = torch.ones(input_shape)  # example input used only for tracing
fx_model = torch.export.export_for_training(model, args=(ex_input,)).module()
# or
# fx_model = torch.export.export(model, args=(ex_input,)).module()

# Step 4: Run the quantization pipeline
quantized_fx_model = nncf.quantize(fx_model, calibration_dataset)
```
ONNX
```python
import onnx
import nncf
import torch
from torchvision import datasets, transforms

# Instantiate your uncompressed model
onnx_model = onnx.load_model("/model_path")

# Provide validation part of the dataset to collect statistics needed for the compression algorithm
val_dataset = datasets.ImageFolder("/path", transform=transforms.Compose([transforms.ToTensor()]))
dataset_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1)

# Step 1: Initialize transformation function
input_name = onnx_model.graph.input[0].name

def transform_fn(data_item):
    images, _ = data_item
    # ONNX models are fed a dict that maps input names to NumPy arrays
    return {input_name: images.numpy()}

# Step 2: Initialize NNCF Dataset
calibration_dataset = nncf.Dataset(dataset_loader, transform_fn)

# Step 3: Run the quantization pipeline
quantized_model = nncf.quantize(onnx_model, calibration_dataset)
```
Training-Time Quantization
Here is an example of a Quantization-Aware Training pipeline in which the model weights and compression parameters may be fine-tuned to achieve higher accuracy.
PyTorch
```python
import nncf
import nncf.torch
import torch
from torchvision import datasets, models, transforms

# Instantiate your uncompressed model
model = models.mobilenet_v2()

# Provide validation part of the dataset to collect statistics needed for the compression algorithm
val_dataset = datasets.ImageFolder("/path", transform=transforms.Compose([transforms.ToTensor()]))
dataset_loader = torch.utils.data.DataLoader(val_dataset)

# Step 1: Initialize the transformation function
def transform_fn(data_item):
    images, _ = data_item
    return images

# Step 2: Initialize NNCF Dataset
calibration_dataset = nncf.Dataset(dataset_loader, transform_fn)

# Step 3: Run the quantization pipeline
quantized_model = nncf.quantize(model, calibration_dataset)

# Now use quantized_model as a usual torch.nn.Module
# to fine-tune compression parameters along with the model weights

# path_to_checkpoint and example_input are provided by your training pipeline

# Save quantization modules and the quantized model parameters
checkpoint = {
    'state_dict': quantized_model.state_dict(),
    'nncf_config': nncf.torch.get_config(quantized_model),
    # ... the rest of the user-defined objects to save
}
torch.save(checkpoint, path_to_checkpoint)

# ...

# Load quantization modules and the quantized model parameters
resuming_checkpoint = torch.load(path_to_checkpoint)
nncf_config = resuming_checkpoint['nncf_config']
state_dict = resuming_checkpoint['state_dict']

quantized_model = nncf.torch.load_from_config(model, nncf_config, example_input)
quantized_model.load_state_dict(state_dict)
# ... the rest of the usual PyTorch-powered training pipeline
```
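Fine-tuning a quantized model works because quantizers are commonly simulated with "fake quantization" (quantize, then immediately dequantize) in the forward pass, while gradients skip the non-differentiable rounding via a straight-through estimator. The plain-Python sketch below illustrates that general technique for a single scalar with a fixed scale; the function names and the symmetric int8 scheme are assumptions, and this is not NNCF's quantizer implementation.

```python
def fake_quantize(x, scale, qmin=-128, qmax=127):
    """Forward pass: symmetric int8 quantize-dequantize (the result stays a float)."""
    q = max(qmin, min(qmax, round(x / scale)))
    return q * scale


def fake_quantize_grad(x, upstream_grad, scale, qmin=-128, qmax=127):
    """Backward pass with a straight-through estimator (STE).

    Rounding is treated as the identity, so the gradient passes through unchanged
    inside the representable range and is zeroed where the input was clipped.
    """
    inside = qmin * scale <= x <= qmax * scale
    return upstream_grad if inside else 0.0


# Toy training loop: tune one float weight w so that its quantized value reaches a target.
scale = 0.05
w, target, lr = 0.333, 1.0, 0.5
for _ in range(20):
    y = fake_quantize(w, scale)
    loss_grad = 2.0 * (y - target)  # d(loss)/dy for loss = (y - target) ** 2
    w -= lr * fake_quantize_grad(w, loss_grad, scale)
print(round(w, 4), fake_quantize(w, scale))
```

The float weight keeps receiving updates even though the forward pass only ever sees quantized values, which is what lets training recover accuracy lost to quantization.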
Demos, Tutorials and Samples
For a quicker start with NNCF-powered compression, try sample notebooks and scripts presented below.
Jupyter* Notebook Tutorials and Demos
Ready-to-run Jupyter* notebook tutorials and demos are available to explain and display NNCF compression algorithms for optimizing models for inference with the OpenVINO Toolkit:
| Notebook Tutorial Name | Compression Algorithm | Backend | Domain |
|---|---|---|---|
| BERT Quantization | Post-Training Quantization | OpenVINO | NLP |
| MONAI Segmentation Model Quantization | Post-Training Quantization | OpenVINO | Segmentation |
| PyTorch Model Quantization | Post-Training Quantization | PyTorch | Image Classification |
| YOLOv11 Quantization with Accuracy Control | Post-Training Quantization with Accuracy Control | OpenVINO | Speech-to-Text, Object Detection |
| PyTorch Training-Time Compression | Training-Time Compression | PyTorch | Image Classification |
A list of notebooks demonstrating OpenVINO conversion and inference together with NNCF compression for models from various domains:
| Demo Model | Compression Algorithm | Backend | Domain |
|---|---|---|---|
| YOLOv8 | Post-Training Quantization | OpenVINO | Object Detection, KeyPoint Detection, Instance Segmentation |
| EfficientSAM | Post-Training Quantization | OpenVINO | Image Segmentation |
| Segment Anything Model | Post-Training Quantization | OpenVINO | Image Segmentation |
| OneFormer | Post-Training Quantization | OpenVINO | Image Segmentation |
| CLIP | Post-Training Quantization | OpenVINO | Image-to-Text |
| BLIP | Post-Training Quantization | OpenVINO | Image-to-Text |
| Latent Consistency Model | Post-Training Quantization | OpenVINO | Text-to-Image |
| SDXL-turbo | Post-Training Quantization | OpenVINO | Text-to-Image, Image-to-Image |
| Distil-Whisper | Post-Training Quantization | OpenVINO | Speech-to-Text |
| Whisper | Post-Training Quantization | OpenVINO | Speech-to-Text |
| MMS Speech Recognition | Post-Training Quantization | OpenVINO | Speech-to-Text |
| Grammar Error Correction | Post-Training Quantization | OpenVINO | NLP, Grammar Correction |
| LLM Instruction Following | Weight Compression | OpenVINO | NLP, Instruction Following |
| LLM Chat Bots | Weight Compression | OpenVINO | NLP, Chat Bot |
Post-Training Quantization and Weight Compression Examples
Compact scripts demonstrating quantization/weight compression and corresponding inference speed boost:
| Example Name | Compression Algorithm | Backend | Domain |
|---|---|---|---|
| OpenVINO MobileNetV2 | Post-Training Quantization | OpenVINO | Image Classification |
| OpenVINO YOLOv8 | Post-Training Quantization | OpenVINO | Object Detection |
| OpenVINO YOLOv8 QwAC | Post-Training Quantization with Accuracy Control | OpenVINO | Object Detection |
| OpenVINO Anomaly Classification | Post-Training Quantization with Accuracy Control | OpenVINO | Anomaly Classification |
| PyTorch MobileNetV2 | Post-Training Quantization | PyTorch | Image Classification |
| PyTorch SSD | Post-Training Quantization | PyTorch | Object Detection |
| TorchFX Resnet18 | Post-Training Quantization | TorchFX | Image Classification |
| ONNX MobileNetV2 | Post-Training Quantization | ONNX | Image Classification |
| ONNX YOLOv8 QwAC | Post-Training Quantization with Accuracy Control | ONNX | Object Detection |
| ONNX TinyLlama WC | Weight Compression | ONNX | LLM |
| TorchFX TinyLlama WC | Weight Compression | TorchFX | LLM |
| OpenVINO TinyLlama WC | Weight Compression | OpenVINO | LLM |
| OpenVINO TinyLlama WC with HS | Weight Compression with Hyperparameters Search | OpenVINO | LLM |
| ONNX TinyLlama WC with SE | Weight Compression with Scale Estimation | ONNX | LLM |
Quantization-Aware Training Examples
| Example Name | Compression Algorithm | Backend | Domain |
|---|---|---|---|
| PyTorch Resnet18 | Quantization-Aware Training | PyTorch | Image Classification |
| PyTorch Anomalib | Quantization-Aware Training | PyTorch | Anomaly Detection |
Third-party Repository Integration
NNCF may be easily integrated into training/evaluation pipelines of third-party repositories.
Used by
- HuggingFace Optimum Intel
NNCF is used as a compression backend within the renowned transformers repository in HuggingFace Optimum Intel. For instance, the command below exports the Llama-3.2-3B-Instruct model to OpenVINO format with INT4-quantized weights:
optimum-cli export openvino -m meta-llama/Llama-3.2-3B-Instruct --weight-format int4 ./Llama-3.2-3B-Instruct-int4
- Ultralytics
NNCF is integrated into the Intel OpenVINO export pipeline, enabling quantization for the exported models.
- ExecuTorch
NNCF is used as the primary quantization framework for the ExecuTorch OpenVINO integration.
- torch.compile
NNCF is used as the primary quantization framework for the torch.compile OpenVINO integration.
- OpenVINO Training Extensions
NNCF is integrated into OpenVINO Training Extensions as a model optimization backend. You can train, optimize, and export new models based on available model templates, as well as run the exported models with OpenVINO.
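The --weight-format int4 option of the optimum-cli command above stores model weights as 4-bit integers. To illustrate the general idea, here is a plain-Python sketch of group-wise asymmetric 4-bit weight quantization, in which every group of consecutive weights keeps its own scale and minimum. The group size of 4, the (scale, minimum) parameterization, and the helper names are assumptions chosen for readability; this is not the NNCF or Optimum Intel implementation, and real LLM compression uses much larger groups.

```python
def quantize_groupwise(weights, group_size=4, bits=4):
    """Quantize a flat weight list group by group to unsigned `bits`-bit codes.

    Each group keeps its own (scale, minimum) pair, so an outlier only
    degrades the precision of the group it belongs to.
    """
    levels = 2 ** bits - 1  # 4-bit codes span 0..15
    groups = []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / levels if hi > lo else 1.0
        codes = [round((w - lo) / scale) for w in group]
        groups.append((codes, scale, lo))
    return groups


def dequantize_groupwise(groups):
    """Reconstruct approximate float weights from (codes, scale, minimum) groups."""
    return [c * scale + lo for codes, scale, lo in groups for c in codes]


weights = [0.12, -0.40, 0.33, 0.05, 2.50, -0.01, 0.07, 0.18]
groups = quantize_groupwise(weights)
restored = dequantize_groupwise(groups)
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print("codes:", [codes for codes, _, _ in groups])
print("max abs error:", round(max_err, 4))
```

The second group contains the outlier 2.50, so its step size, and therefore its reconstruction error, is noticeably larger than the first group's; smaller groups trade extra storage for scales against better accuracy.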
Installation Guide
For detailed installation instructions, refer to the Installation guide.
NNCF can be installed as a regular PyPI package via pip:
pip install nncf
NNCF is also available via conda:
conda install -c conda-forge nncf
System requirements of NNCF depend on the backend used. System requirements for each backend and the matrix of corresponding versions can be found in installation.md.
NNCF Compressed Model Zoo
A list of models and their compression results can be found on our NNCF Model Zoo page.
Citing
```bibtex
@article{kozlov2020neural,
    title =   {Neural network compression framework for fast model inference},
    author =  {Kozlov, Alexander and Lazarevich, Ivan and Shamporov, Vasily and Lyalyushkin, Nikolay and Gorbachev, Yury},
    journal = {arXiv preprint arXiv:2002.08679},
    year =    {2020}
}
```
Contributing Guide
Refer to the CONTRIBUTING.md file for guidelines on contributions to the NNCF repository.
Useful links
- Documentation
- Examples
- FAQ
- Notebooks
- HuggingFace Optimum Intel
- OpenVINO Model Optimization Guide
- OpenVINO Hugging Face page
- OpenVINO Performance Benchmarks page
Telemetry
NNCF, as part of the OpenVINO™ toolkit, collects anonymous usage data for the purpose of improving OpenVINO™ tools. You can opt out at any time by running the following command in the Python environment where you have NNCF installed:
opt_in_out --opt_out
More information is available on the OpenVINO telemetry page.