Quantization — coremltools API Reference 8.3 documentation

Quantization refers to techniques for performing neural network computations in lower precision than floating point. Quantization can reduce a model’s size and also improve a model’s inference latency and memory bandwidth requirement, because many hardware platforms offer high-performance implementations of quantized operations.

class coremltools.optimize.torch.quantization.ModuleLinearQuantizerConfig(algorithm: str = 'vanilla', weight_dtype: str | dtype = torch.qint8, weight_observer=ObserverType.moving_average_min_max, weight_per_channel: bool = True, activation_dtype: str | dtype = torch.quint8, activation_observer=ObserverType.moving_average_min_max, quantization_scheme=QuantizationScheme.symmetric, milestones: List[int] | None = None)[source]

Configuration class for specifying global and module-level quantization options for the linear quantization algorithm implemented in LinearQuantizer.

The linear quantization algorithm simulates the effects of quantization during training by quantizing and dequantizing the weights and/or activations during the model's forward pass. The forward and backward pass computations are conducted in float dtype; however, these float values follow the constraints imposed by the int8 and quint8 dtypes. For more details, please refer to Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference.
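As a rough illustration only (a simplified sketch, not the library's internal implementation), simulating quantization amounts to a quantize/dequantize round trip applied to a float tensor during the forward pass, with the scale and zero point assumed to be given:

import torch

def fake_quantize(x: torch.Tensor, scale: float, zero_point: int,
                  quant_min: int = -128, quant_max: int = 127) -> torch.Tensor:
    """Quantize-dequantize round trip that simulates int8 numerics in float."""
    # quantize: map float values onto the integer grid and clamp to the dtype range
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    # dequantize: map back to float, carrying the rounding/clamping error with it
    return (q - zero_point) * scale

x = torch.randn(4)
x_hat = fake_quantize(x, scale=0.02, zero_point=0)  # values snap to multiples of 0.02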

For most applications, the only parameters that need to be set are quantization_scheme and milestones.

By default, quantization_scheme is set to QuantizationScheme.symmetric, which means all weights are quantized with zero point as zero, and activations are quantized with zero point as zero for non-negative activations and 128 for all other activations. The weights are quantized using torch.qint8 and activations are quantized using torch.quint8.

The linear quantization algorithm inserts observers for each weight/activation tensor. These observers collect statistics of the tensors' values, for example, the minimum and maximum values they can take. These statistics are then used to compute the scale and zero point, which are in turn used for quantizing the weights/activations. By default, the moving_average_min_max observer is used. For more details, please check MinMaxObserver.
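For intuition only, a simplified sketch (not the observer implementation) of how the observed min/max statistics translate into a scale and zero point for the affine and symmetric schemes:

def affine_qparams(x_min: float, x_max: float, quant_min: int = 0, quant_max: int = 255):
    """Affine scheme: derive scale and zero point from observed min/max (simplified)."""
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # keep zero exactly representable
    scale = max((x_max - x_min) / (quant_max - quant_min), 1e-8)
    zero_point = int(round(quant_min - x_min / scale))
    return scale, zero_point

def symmetric_qparams(x_min: float, x_max: float, quant_min: int = -128, quant_max: int = 127):
    """Symmetric scheme: zero point is fixed at zero; the scale covers the larger magnitude."""
    max_abs = max(abs(x_min), abs(x_max))
    scale = max(2.0 * max_abs / (quant_max - quant_min), 1e-8)
    return scale, 0

scale, zero_point = affine_qparams(x_min=-0.7, x_max=1.3)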

The milestones parameter controls the flow of the quantization algorithm. The example below illustrates its usage in more detail:

model = define_model()

config = LinearQuantizerConfig(
    global_config=ModuleLinearQuantizerConfig(
        quantization_scheme="symmetric",
        milestones=[0, 100, 300, 200],
    )
)

quantizer = LinearQuantizer(model, config)

# prepare the model to insert FakeQuantize layers for QAT
model = quantizer.prepare()

# use quantizer in your PyTorch training loop
for inputs, labels in data:
    output = model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    quantizer.step()

In this example, from step 0 onwards, observers will collect statistics of the values of weights/activations. However, between steps 0 and 100, the effects of quantization will not be simulated. At step 100, quantization simulation will begin, and at step 300, observer statistics collection will stop. A batch norm layer computes the mean and variance of the input batch for normalizing it during training, and collects running estimates of its computed mean and variance, which are then used for normalization during evaluation. At step 200, batch norm statistics collection is frozen, and the batch norm layers switch to evaluation mode, thus more closely simulating the inference numerics during training time.

Parameters:

as_dict() → Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(config_dict)[source]

Create class from a dictionary of string keys and values.

Parameters:

config_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: IO | str) → DictableDataClass

Create class from a yaml stream.

Parameters:

yml – An IO stream containing yaml, or a str path to the yaml file.

class coremltools.optimize.torch.quantization.LinearQuantizerConfig(global_config: ModuleLinearQuantizerConfig | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: Dict[str, ModuleLinearQuantizerConfig | None] = NOTHING, non_traceable_module_names: List[str] = [], preserved_attributes: List[str] = NOTHING)[source]

Configuration class for specifying how different submodules of a model are quantized by LinearQuantizer.

In order to disable quantizing a layer or an operation, the module_type_configs or module_name_configs entry corresponding to that operation can be set to None.

For example:

The following config will enable weight only quantization for all layers:

config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "activation_dtype": "float32",
        }
    }
)

The following config will disable quantization for all linear layers and set quantization mode to weight only quantization for convolution layers:

config = LinearQuantizerConfig.from_dict(
    {
        "module_type_configs": {
            "Linear": None,
            "Conv2d": {
                "activation_dtype": "float32",
            },
        }
    }
)

The following config will disable quantization for layers named conv1 and conv2:

config = LinearQuantizerConfig.from_dict(
    {
        "module_name_configs": {
            "conv1": None,
            "conv2": None,
        }
    }
)

If the model has methods and attributes which are not used in the forward pass but need to be preserved after quantization is added, they can be retained on the quantized model by passing them in the preserved_attributes parameter:

model = MyModel()
model.key_1 = value_1
model.key_2 = value_2

config = LinearQuantizerConfig.from_dict({"preserved_attributes": ["key_1", "key_2"]})

Parameters:

Note

The quantization_scheme parameter must be the same across all configs.
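For example, a config that combines a global setting with a per-module override must use the same quantization_scheme in both entries (the module name and override below are illustrative only):

config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "quantization_scheme": "symmetric",
        },
        "module_name_configs": {
            "conv2": {
                "quantization_scheme": "symmetric",  # must match the global scheme
                "weight_per_channel": False,
            },
        },
    }
)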

as_dict() → Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(config_dict: Dict[str, Any]) → LinearQuantizerConfig[source]

Create class from a dictionary of string keys and values.

Parameters:

config_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: IO | str) → DictableDataClass

Create class from a yaml stream.

Parameters:

yml – An IO stream containing yaml, or a str path to the yaml file.
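As a sketch of how from_yaml can be used (the YAML content below is illustrative, mirroring the dictionary configs above):

import io

yaml_config = io.StringIO(
    """
global_config:
  quantization_scheme: symmetric
  milestones: [0, 100, 400, 400]
"""
)
config = LinearQuantizerConfig.from_yaml(yaml_config)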

set_global(global_config: ModuleOptimizationConfig | None) → OptimizationConfig

Set the global config.

set_module_name(module_name: str, opt_config: ModuleOptimizationConfig | None) → OptimizationConfig

Set the module level optimization config for a given module instance. If the module level optimization config for an existing module was already set, the new config will override the old one.

set_module_type(object_type: Callable | str, opt_config: ModuleOptimizationConfig | None) → OptimizationConfig

Set the module level optimization config for a given module type. If the module level optimization config for an existing module type was already set, the new config will override the old one.
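The same configuration can also be built programmatically with these setters; a brief sketch, assuming torch and the config classes above are imported (the module name is illustrative):

config = LinearQuantizerConfig().set_global(
    ModuleLinearQuantizerConfig(quantization_scheme="symmetric")
)

# disable quantization for a specific layer
config = config.set_module_name("conv1", None)

# weight-only quantization for all Linear layers
config = config.set_module_type(
    torch.nn.Linear, ModuleLinearQuantizerConfig(activation_dtype="float32")
)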

class coremltools.optimize.torch.quantization.LinearQuantizer(model: Module, config: LinearQuantizerConfig | None = None)[source]

Perform quantization aware training (QAT) of models. This algorithm simulates the effects of quantization during training by quantizing and dequantizing the weights and/or activations during the model's forward pass. The forward and backward pass computations are conducted in float dtype; however, these float values follow the constraints imposed by the int8 and quint8 dtypes. Thus, the model's weights are updated while closely simulating the numerics executed during quantized inference, allowing them to adapt to the quantization constraints.

For more details, please refer to Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference.

Example

import torch.nn as nn
from collections import OrderedDict  # needed for the named nn.Sequential below

from coremltools.optimize.torch.quantization import (
    LinearQuantizer,
    LinearQuantizerConfig,
)

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

loss_fn = define_loss()

# initialize the quantizer
config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "quantization_scheme": "symmetric",
            "milestones": [0, 100, 400, 400],
        }
    }
)

quantizer = LinearQuantizer(model, config)

# prepare the model to insert FakeQuantize layers for QAT
model = quantizer.prepare()

# use quantizer in your PyTorch training loop
for inputs, labels in data:
    output = model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    quantizer.step()

# convert operations to their quantized counterparts using parameters learned via QAT
model = quantizer.finalize(inplace=True)

Parameters:

finalize(model: Module | None = None, inplace: bool = False) → Module[source]

Prepares the model for export.

Parameters:

Note

Once the model is finalized with inplace=True, it may not be runnable on the GPU.

prepare(example_inputs: Tuple[Any, ...], inplace: bool = False) → Module[source]

Prepares the model for quantization aware training by inserting torch.ao.quantization.FakeQuantize layers in the model at appropriate places.

Parameters:

Note

This method uses the prepare_qat_fx method to insert quantization layers, and the returned model is a torch.fx.GraphModule. Some models, like those with dynamic control flow, may not be traceable into a torch.fx.GraphModule. Please follow the directions in Limitations of Symbolic Tracing to update your model first before using the LinearQuantizer algorithm.
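Per the signature above, prepare takes example_inputs used for tracing; a hedged sketch, where the input shape is an assumption matching the convolutional example above:

import torch

example_inputs = (torch.randn(1, 1, 28, 28),)  # hypothetical input shape for the example model
prepared_model = quantizer.prepare(example_inputs=example_inputs, inplace=True)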

report() → _Report[source]

Returns a dictionary with important statistics related to current state of quantization. Each key in the dictionary corresponds to a module name, and the value is a dictionary containing the statistics such as scale, zero point, number of parameters, and so on.

Note that error will be NaN and #params will be -1 for activations.

step()[source]

Steps through the milestones defined for this quantizer.

The first milestone corresponds to enabling observers, the second to enabling fake quantization simulation, the third to disabling observers, and the last to freezing batch norm statistics.

Note

If the milestones argument is set to None, this method is a no-op.

Note

To skip a particular milestone, its value can be set to -1.
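For instance, an illustrative config (not from the original documentation) that skips the batch norm freezing milestone while keeping the others:

config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "milestones": [0, 100, 300, -1],  # batch norm statistics are never frozen
        }
    }
)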

class coremltools.optimize.torch.quantization.ObserverType(value)[source]

An enum indicating the type of observer. Allowed options are moving_average_min_max, min_max, ema_min_max.

class coremltools.optimize.torch.quantization.QuantizationScheme(value)[source]

An enum indicating the type of quantization to be performed. Allowed options are symmetric and affine.

class coremltools.optimize.torch.quantization.ModulePostTrainingQuantizerConfig(weight_dtype: str | dtype = 'int8', granularity='per_channel', quantization_scheme=QuantizationScheme.symmetric, block_size=None)[source]

Configuration class for specifying global and module-level quantizer options for the PostTrainingQuantizer algorithm.

Parameters:

This class supports three different configurations to structure the quantization:

1. Per-channel quantization: This is the default configuration, where granularity is per_channel and block_size is None. In this configuration, quantization parameters are computed for each output channel.

2. Per-tensor quantization: In this configuration, quantization parameters are computed for the tensor as a whole. That is, all values in the tensor will share a single scale and, if applicable, a single zero point. The granularity argument is set to per_tensor.

3. Per-block quantization: This configuration is used to structure the tensor for blockwise quantization. The granularity is set to per_block, and the block_size argument has to be specified. The block_size argument can either be of type int or tuple.
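The three configurations above can be sketched directly against the ModulePostTrainingQuantizerConfig constructor (the dtype and block size values are illustrative):

from coremltools.optimize.torch.quantization import ModulePostTrainingQuantizerConfig

# per-channel (default): quantization parameters computed per output channel
per_channel = ModulePostTrainingQuantizerConfig(weight_dtype="int8", granularity="per_channel")

# per-tensor: a single scale (and zero point, if applicable) shared by the whole tensor
per_tensor = ModulePostTrainingQuantizerConfig(weight_dtype="int8", granularity="per_tensor")

# per-block: blockwise quantization; block_size must be specified
per_block = ModulePostTrainingQuantizerConfig(
    weight_dtype="int4", granularity="per_block", block_size=32
)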

Note

When performing 4-bit quantization, weight_dtype is set to torch.int8 for int4 or torch.uint8 for uint4. This is because PyTorch currently doesn't provide support for 4-bit data types. However, the quantization range is set according to 4-bit quantization and based on whether the weight_dtype is signed or unsigned.

class coremltools.optimize.torch.quantization.PostTrainingQuantizer(model: Module, config: PostTrainingQuantizerConfig | None = None)[source]

Perform post-training quantization on a torch model. After quantization, weights of all submodules selected for quantization contain full precision values obtained by quantizing and dequantizing the original weights, which captures the error induced by quantization.

Note

After quantization, the weight values stored will still remain in full precision, so the PyTorch model size will not be reduced. To see the reduction in model size, please convert the model using coremltools.convert(...), which will produce a model intermediate language (MIL) model containing the compressed weights.

Example:

import torch.nn as nn
from collections import OrderedDict  # needed for the named nn.Sequential below

from coremltools.optimize.torch.quantization import (
    PostTrainingQuantizerConfig,
    PostTrainingQuantizer,
)

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

# initialize the quantizer
config = PostTrainingQuantizerConfig.from_dict(
    {
        "global_config": {
            "weight_dtype": "int8",
        },
    }
)

ptq = PostTrainingQuantizer(model, config)
quantized_model = ptq.compress()
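To realize the size reduction mentioned in the note above, the compressed model can then be converted with coremltools; a hedged sketch that assumes the quantized_model from this example, a hypothetical input shape, and an arbitrary deployment target:

import torch
import coremltools as ct

example_input = torch.rand(1, 1, 28, 28)  # hypothetical input shape for the example model
traced_model = torch.jit.trace(quantized_model.eval(), example_input)

mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=example_input.shape)],
    minimum_deployment_target=ct.target.iOS17,
)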

Parameters:

class coremltools.optimize.torch.quantization.PostTrainingQuantizerConfig(global_config: ModulePostTrainingQuantizerConfig | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: Dict[str, ModulePostTrainingQuantizerConfig | None] = NOTHING)[source]

Configuration class for specifying how different submodules of a model should be post-training quantized by PostTrainingQuantizer.

Parameters:

class coremltools.optimize.torch.layerwise_compression.LayerwiseCompressorConfig(layers: List[Module | str] | ModuleList | None = None, global_config: LayerwiseCompressionAlgorithmConfig | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: Dict[str, LayerwiseCompressionAlgorithmConfig | None] = NOTHING, input_cacher: Any = 'default', calibration_nsamples: int = 128)[source]

Configuration class for specifying how different submodules of a model are compressed by LayerwiseCompressor. Note that only sequential models are supported.

Parameters:

as_dict() → Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(config_dict: Dict[str, Any]) → LayerwiseCompressorConfig[source]

Create class from a dictionary of string keys and values.

Parameters:

config_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: IO | str) → DictableDataClass

Create class from a yaml stream.

Parameters:

yml – An IO stream containing yaml, or a str path to the yaml file.

class coremltools.optimize.torch.layerwise_compression.LayerwiseCompressor(model: Module, config: LayerwiseCompressorConfig)[source]

A post-training compression algorithm which compresses a sequential model layer by layer by minimizing the quantization error while quantizing the weights. The implementation supports two variations of this algorithm:

  1. Generative Pre-Trained Transformer Quantization (GPTQ)
  2. Sparse Generative Pre-Trained Transformer (SparseGPT)

At a high level, it compresses weights of a model layer by layer by minimizing the L2 norm of the difference between the original activations and activations obtained from compressing the weights of a layer. The activations are computed using a few samples of training data.

Only sequential models are supported, where the output of one layer feeds into the input of the next layer.

For HuggingFace models, disable the use_cache config. This option is used to speed up decoding, but to generalize the forward pass for LayerwiseCompressor algorithms across all model types, it must be disabled.
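For example, a sketch assuming a Hugging Face transformers causal language model (the checkpoint name is hypothetical):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("my-org/my-model")  # hypothetical checkpoint
model.config.use_cache = False  # disable KV caching before running LayerwiseCompressor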

Example

import torch.nn as nn
from collections import OrderedDict  # needed for the named nn.Sequential below

from coremltools.optimize.torch.layerwise_compression import (
    LayerwiseCompressor,
    LayerwiseCompressorConfig,
)

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

dataloader = load_calibration_data()

# initialize the compressor
config = LayerwiseCompressorConfig.from_dict(
    {
        "global_config": {
            "algorithm": "gptq",
            "weight_dtype": "int4",
        },
        "input_cacher": "default",
        "calibration_nsamples": 16,
    }
)

compressor = LayerwiseCompressor(model, config)

# device is required by the compress signature; "cpu" is used here as an example
compressed_model = compressor.compress(dataloader, device="cpu")

Parameters:

compress(dataloader: Iterable, device: str, inplace: bool = False) → Module[source]

Compresses model using samples from dataloader.

Parameters:

GPTQ

class coremltools.optimize.torch.layerwise_compression.algorithms.ModuleGPTQConfig(weight_dtype: str | dtype = 'uint8', granularity='per_channel', quantization_scheme='symmetric', block_size: int | None = None, enable_normal_float: bool = False, hessian_dampening: float = 0.01, use_activation_order_heuristic: bool = False, processing_group_size: int = 128, algorithm: str = 'gptq')[source]

Bases: LayerwiseCompressionAlgorithmConfig

Configuration class for specifying global and module-level compression options for the Generative Pre-Trained Transformer Quantization (GPTQ) algorithm.

Parameters:
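A hedged sketch of a GPTQ configuration exercising some of these options through LayerwiseCompressorConfig (the values are illustrative only):

config = LayerwiseCompressorConfig.from_dict(
    {
        "global_config": {
            "algorithm": "gptq",
            "weight_dtype": "int4",
            "granularity": "per_block",
            "block_size": 32,
            "hessian_dampening": 0.01,
            "use_activation_order_heuristic": True,
            "processing_group_size": 128,
        }
    }
)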

class coremltools.optimize.torch.layerwise_compression.algorithms.GPTQ(layer: Module, config: ModuleGPTQConfig)[source]

Bases: OBSCompressionAlgorithm

A post-training compression algorithm based on the paper GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.

Parameters: