Palettization — coremltools API Reference 8.3 documentation

Palettization is a mechanism for compressing a model by clustering the model’s float weights into a lookup table (LUT) of centroids and indices.

Palettization is implemented as an extension of PyTorch's QAT APIs. It works by inserting palettization layers in appropriate places inside a model. The model can then be fine-tuned to learn the new palettized layers' weights in the form of a LUT and indices.

class coremltools.optimize.torch.palettization.ModuleDKMPalettizerConfig(n_bits: int | None = None, weight_threshold: int = 2048, granularity='per_tensor', group_size: int | None = None, channel_axis: int = 0, enable_per_channel_scale: bool = False, milestone: int = 0, cluster_dim: int | None = None, quant_min: int = -128, quant_max: int = 127, dtype: str | dtype = torch.qint8, lut_dtype: str = 'f32', quantize_activations: bool = False, cluster_permute: tuple | None = None, palett_max_mem: float = 1.0, kmeans_max_iter: int = 3, prune_threshold: float = 1e-07, kmeans_init: str = 'auto', kmeans_opt1d_threshold: int = 1024, enforce_zero: bool = False, palett_mode: str = 'dkm', palett_tau: float = 0.0001, palett_epsilon: float = 0.0001, palett_lambda: float = 0.0, add_extra_centroid: bool = False, palett_cluster_tol: float = 0.0, palett_min_tsize: int = 65536, palett_unique: bool = False, palett_shard: bool = False, palett_batch_mode: bool = False, palett_dist: bool = False, per_channel_scaling_factor_scheme: str = 'min_max', percentage_palett_enable: float = 1.0, kmeans_batch_threshold: int = 4, kmeans_n_init: int = 10, zero_threshold: float = 1e-07, kmeans_error_bnd: float = 0.0, partition_size: int | None = None, cluster_dtype: str | None = None)[source]

Configuration class for specifying global and module-level options for the palettization algorithm implemented in DKMPalettizer.

The parameters specified in this config control the DKM algorithm, described in DKM: Differentiable K-Means Clustering Layer for Neural Network Compression.

For most use cases, the only parameters you need to specify are n_bits, weight_threshold, and milestone.

Note

Most of the parameters in this class are meant for advanced use cases and for further fine-tuning the DKM algorithm. The default values usually work for a majority of tasks.

Note

Change the following parameters only when you use activation quantization in conjunction with DKM weight palettization: quant_min, quant_max, dtype, and quantize_activations.

Parameters:

This class supports two different configurations to structure the palettization:

1. Per-tensor palettization: This is the default configuration where the whole tensor shares a single lookup table. The granularity is set to per_tensor and group_size is None.

2. Per-grouped-channel palettization: In this configuration, group_size number of channels along channel_axis share the same lookup table. For example, for a weight matrix of shape (16, 25), if we provide group_size = 8, the shape of the lookup table would be (2, 2^n_bits). A config sketch follows the note below.

Note

Grouping is currently only supported along the output channel axis.
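For illustration, a minimal sketch of a per-grouped-channel config, assuming "per_grouped_channel" is the granularity value corresponding to option 2 above:

from coremltools.optimize.torch.palettization import (
    DKMPalettizerConfig,
    ModuleDKMPalettizerConfig,
)

# e.g. 16 output channels grouped 8 at a time -> 2 lookup tables per weight tensor
global_config = ModuleDKMPalettizerConfig(
    n_bits=4,
    granularity="per_grouped_channel",  # assumed string value for grouped granularity
    group_size=8,
    channel_axis=0,  # grouping along the output channel axis
)
config = DKMPalettizerConfig(global_config=global_config)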

as_dict() → Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(data_dict: Dict[str, Any]) → DictableDataClass

Create class from a dictionary of string keys and values.

Parameters:

data_dict (dict of str and values) – A nested dictionary of strings and values.
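For example, a brief sketch assuming the dictionary keys mirror the constructor arguments listed above:

from coremltools.optimize.torch.palettization import ModuleDKMPalettizerConfig

# keys correspond to constructor arguments of ModuleDKMPalettizerConfig
module_config = ModuleDKMPalettizerConfig.from_dict(
    {"n_bits": 4, "weight_threshold": 1024, "milestone": 1}
)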

classmethod from_yaml(yml: IO | str) → DictableDataClass

Create class from a yaml stream.

Parameters:

yml – An IO stream containing yaml or a str path to the yaml file.
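For example, a sketch assuming the YAML keys mirror the constructor arguments of this class; the file name is illustrative:

from coremltools.optimize.torch.palettization import ModuleDKMPalettizerConfig

# contents of palettization_config.yaml (illustrative file name):
#   n_bits: 4
#   weight_threshold: 2048
#   milestone: 1
module_config = ModuleDKMPalettizerConfig.from_yaml("palettization_config.yaml")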

class coremltools.optimize.torch.palettization.DKMPalettizerConfig(global_config: GlobalConfigType | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: ModuleNameConfigType = NOTHING)[source]

Configuration for specifying how different submodules of a model are palettized by DKMPalettizer.

The module_type_configs parameter can accept a list of ModuleDKMPalettizerConfig as values for a given module type. The list can specify different parameters for different weight_threshold values. This is useful if you want to apply different configs to layers of the same type with weights of different sizes.

For example, to use 4-bit palettization for weights with more than 1000 elements and 2-bit palettization for weights with more than 300 but less than 1000 elements, create a config as follows:

custom_config = {
    nn.Conv2d: [
        {"n_bits": 4, "cluster_dim": 4, "weight_threshold": 1000},
        {"n_bits": 2, "cluster_dim": 2, "weight_threshold": 300},
    ]
}
config = DKMPalettizerConfig.from_dict({"module_type_configs": custom_config})

Parameters:

as_dict() → Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(config_dict: Dict[str, Any]) → DKMPalettizerConfig[source]

Create class from a dictionary of string keys and values.

Parameters:

config_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: IO | str) → DictableDataClass

Create class from a yaml stream.

Parameters:

yml – An IO stream containing yaml or a str path to the yaml file.

set_global(global_config: ModuleOptimizationConfig | None) → OptimizationConfig

Set the global config.

set_module_name(module_name: str, opt_config: ModuleOptimizationConfig | None) → OptimizationConfig

Set the module level optimization config for a given module instance. If the module level optimization config for an existing module was already set, the new config will override the old one.

set_module_type(object_type: Callable | str, opt_config: ModuleOptimizationConfig | None) → OptimizationConfig

Set the module level optimization config for a given module type. If the module level optimization config for an existing module type was already set, the new config will override the old one.
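As an alternative to from_dict, the setter methods above can assemble the same configuration programmatically. A small sketch; the module name "conv2" is a hypothetical layer name:

import torch.nn as nn

from coremltools.optimize.torch.palettization import (
    DKMPalettizerConfig,
    ModuleDKMPalettizerConfig,
)

config = DKMPalettizerConfig(global_config=ModuleDKMPalettizerConfig(n_bits=4))
# all Linear layers get a 6-bit palette
config = config.set_module_type(nn.Linear, ModuleDKMPalettizerConfig(n_bits=6))
# one specific layer, "conv2" (hypothetical name), gets a 2-bit palette
config = config.set_module_name("conv2", ModuleDKMPalettizerConfig(n_bits=2))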

class coremltools.optimize.torch.palettization.DKMPalettizer(model: Module, config: DKMPalettizerConfig | None = None)[source]

A palettization algorithm based on “DKM: Differentiable K-Means Clustering Layer for Neural Network Compression”. It clusters the weights using a differentiable version of k-means, allowing the lookup table (LUT) and indices of palettized weights to be learnt using a gradient-based optimization algorithm such as SGD.

Example

import torch

from coremltools.optimize.torch.palettization import (
    DKMPalettizer,
    DKMPalettizerConfig,
    ModuleDKMPalettizerConfig,
)

# code that defines the pytorch model, loss and optimizer.
model, loss_fn, optimizer = create_model_loss_and_optimizer()

# initialize the palettizer
config = DKMPalettizerConfig(global_config=ModuleDKMPalettizerConfig(n_bits=4))
palettizer = DKMPalettizer(model, config)

# prepare the model to insert FakePalettize layers for palettization
model = palettizer.prepare(inplace=True)

# use palettizer in your PyTorch training loop
for inputs, labels in data:
    output = model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    palettizer.step()

# fold LUT and indices into weights
model = palettizer.finalize(inplace=True)

Parameters:

finalize(model: Module | None = None, inplace: bool = False) → Module[source]

Removes FakePalettize layers from a model and creates new model weights from the LUT and indices buffers.

This function is called to prepare a palettized model for export using coremltools.
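A minimal export sketch, continuing from the training-loop example above; the example input shape and the iOS18 deployment target are illustrative assumptions, not requirements stated here:

import coremltools as ct
import torch

# fold the learnt LUT and indices back into regular weight tensors
finalized_model = palettizer.finalize()
finalized_model.eval()

example_input = torch.rand(1, 3, 224, 224)  # assumed input shape
traced_model = torch.jit.trace(finalized_model, example_input)

mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=example_input.shape)],
    # assumption: choose a deployment target that supports the palettization options used
    minimum_deployment_target=ct.target.iOS18,
)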

Parameters:

prepare(inplace: bool = False) → Module[source]

Prepares a model for palettization aware training by inserting FakePalettize layers in appropriate places as specified by the config.

Parameters:

inplace (bool) – If True, model transformations are carried out in-place and the original module is mutated, otherwise a copy of the model is mutated and returned.

report() → _Report[source]

Returns a dictionary with important statistics related to the current state of palettization. Each key in the dictionary corresponds to a module name, and the value is a dictionary containing statistics such as the number of clusters, cluster dimension, number of parameters, and so on.
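For instance, a brief sketch (assuming the returned report can be iterated like the dictionary described above, and using the palettizer from the example earlier):

report = palettizer.report()
for module_name, stats in report.items():
    # stats holds per-module statistics such as the number of clusters and parameters
    print(module_name, stats)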

step()[source]

Step through the palettizer. When the number of times step is called is equal to milestone, palettization is enabled.
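As a hedged illustration of how milestone interacts with step(): if step() is called once per epoch, a config with milestone=2 leaves palettization disabled until step() has run twice. The helpers train_one_epoch, num_epochs, and data below are hypothetical placeholders; create_model_loss_and_optimizer is the same stand-in used in the example above.

from coremltools.optimize.torch.palettization import (
    DKMPalettizer,
    DKMPalettizerConfig,
    ModuleDKMPalettizerConfig,
)

model, loss_fn, optimizer = create_model_loss_and_optimizer()

# palettization is enabled only once step() has been called `milestone` times
config = DKMPalettizerConfig(
    global_config=ModuleDKMPalettizerConfig(n_bits=4, milestone=2)
)
palettizer = DKMPalettizer(model, config)
model = palettizer.prepare(inplace=True)

for epoch in range(num_epochs):  # num_epochs: hypothetical
    train_one_epoch(model, loss_fn, optimizer, data)  # hypothetical training helper
    palettizer.step()  # turns palettization on once called `milestone` times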

class coremltools.optimize.torch.palettization.ModulePostTrainingPalettizerConfig(n_bits: int | None = 4, lut_dtype=None, granularity='per_tensor', group_size: int | None = None, channel_axis: int = 0, cluster_dim: int | None = None, enable_per_channel_scale: bool | None = False, enable_fast_kmeans_mode: bool | None = True, rounding_precision: int | None = 4)[source]

Configuration class for specifying global and module-level palettization options for the PostTrainingPalettizer algorithm.

Parameters:

This class supports two different configurations to structure the palettization:

1. Per-tensor palettization: This is the default configuration where the whole tensor shares a single lookup table. The granularity is set to per_tensor, and group_size is None.

2. Per-grouped-channel palettization: In this configuration, group_size number of channels along channel_axis share the same lookup table. For example, for a weight matrix of shape (16, 25), if we provide group_size = 8, the shape of the lookup table would be (2, 2^n_bits). A config sketch follows the note below.

Note

Grouping is currently only supported along either the input or output channel axis.
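A hedged sketch of a grouped-channel post-training config built via from_dict; the "per_grouped_channel" value is assumed to correspond to option 2 above:

from coremltools.optimize.torch.palettization import PostTrainingPalettizerConfig

config = PostTrainingPalettizerConfig.from_dict(
    {
        "global_config": {
            "n_bits": 4,
            "granularity": "per_grouped_channel",  # assumed value for grouped granularity
            "group_size": 16,
        }
    }
)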

class coremltools.optimize.torch.palettization.PostTrainingPalettizer(model: Module, config: PostTrainingPalettizerConfig | None = None)[source]

Perform post-training palettization on a torch model. After palettization, all weights in supported layers point to entries in a lookup table computed by a k-means operation.

Example

from collections import OrderedDict

import torch.nn as nn

from coremltools.optimize.torch.palettization import (
    PostTrainingPalettizerConfig,
    PostTrainingPalettizer,
)

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

# initialize the palettizer
config = PostTrainingPalettizerConfig.from_dict(
    {
        "global_config": {
            "n_bits": 4,
        },
    }
)

ptpalettizer = PostTrainingPalettizer(model, config)
palettized_model = ptpalettizer.compress()

Parameters:

class coremltools.optimize.torch.palettization.PostTrainingPalettizerConfig(global_config: ModulePostTrainingPalettizerConfig | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: Dict[str, ModulePostTrainingPalettizerConfig | None] = NOTHING)[source]

Configuration class for specifying how different submodules of a model should be post-training palettized by PostTrainingPalettizer.

Parameters:
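A small sketch of per-module-name overrides; the layer names are hypothetical, and the None entry is assumed (per the type signature above) to leave that module unpalettized:

from coremltools.optimize.torch.palettization import (
    ModulePostTrainingPalettizerConfig,
    PostTrainingPalettizerConfig,
)

config = PostTrainingPalettizerConfig(
    global_config=ModulePostTrainingPalettizerConfig(n_bits=4),
    module_name_configs={
        "conv2": ModulePostTrainingPalettizerConfig(n_bits=2),  # hypothetical layer name
        "conv": None,  # assumed to skip palettization for this module
    },
)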

class coremltools.optimize.torch.palettization.ModuleSKMPalettizerConfig(n_bits: int = 4, lut_dtype=None, granularity='per_tensor', group_size: int | None = None, channel_axis: int = 0, cluster_dim: int | None = None, enable_per_channel_scale: bool = False)[source]

Configuration class for specifying global and module-level palettization options for the SKMPalettizer algorithm.

Parameters:

This class supports two different configurations to structure the palettization:

1. Per-tensor palettization: This is the default configuration where the whole tensor shares a single lookup table. The granularity is set to per_tensor, and group_size is None.

2. Per-grouped-channel palettization: In this configuration, group_size number of channels along channel_axis share the same lookup table. For example, for a weight matrix of shape (16, 25), if we provide group_size = 8, the shape of the lookup table would be (2, 2^n_bits).

Note

Grouping is currently only supported along either the input or output channel axis.

class coremltools.optimize.torch.palettization.SKMPalettizer(model: Module, config: SKMPalettizerConfig | None = None)[source]

Perform post-training palettization of weights by running a weighted k-means on the model weights. The weight values used for weighing different elements of a model’s weight matrix are computed using the Fisher information matrix, which is an approximation of the Hessian. These weight values indicate how sensitive a given weight element is: the more sensitive an element, the larger the impact perturbing or palettizing it has on the model’s loss function. This means that weighted k-means moves the clusters closer to the sensitive weight values, allowing them to be represented more exactly. This leads to a lower degradation in model performance after palettization. The Fisher information matrix is computed using a few samples of calibration data.

This algorithm implements SqueezeLLM: Dense-and-Sparse Quantization.

Example

from collections import OrderedDict

import torch.nn as nn

from coremltools.optimize.torch.palettization import (
    SKMPalettizer,
    SKMPalettizerConfig,
)

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

dataloader = load_calibration_data()

# define callable for loss function
def loss_fn(model, data):
    inp, target = data
    out = model(inp)
    return nn.functional.mse_loss(out, target)

# initialize the palettizer
config = SKMPalettizerConfig.from_dict(
    {
        "global_config": {
            "n_bits": 4,
        },
        "calibration_nsamples": 16,
    }
)

compressor = SKMPalettizer(model, config)
compressed_model = compressor.compress(dataloader=dataloader, loss_fn=loss_fn)

Parameters:

class coremltools.optimize.torch.palettization.SKMPalettizerConfig(global_config: ModuleSKMPalettizerConfig | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: Dict[str, ModuleSKMPalettizerConfig | None] = NOTHING, calibration_nsamples: int = 128)[source]

Configuration class for specifying how different submodules of a model are palettized by SKMPalettizer.

Parameters: