Quantization aware training comprehensive guide
Welcome to the comprehensive guide for Keras quantization aware training.
This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the API docs.
- If you want to see the benefits of quantization aware training and what's supported, see the overview.
- For a single end-to-end example, see the quantization aware training example.
The following use cases are covered:
- Deploy a model with 8-bit quantization with these steps:
  - Define a quantization aware model.
  - For Keras HDF5 models only, use special checkpointing and deserialization logic. Training is otherwise standard.
  - Create a quantized model from the quantization aware one.
- Experiment with quantization.
  - Anything for experimentation has no supported path to deployment.
  - Custom Keras layers fall under experimentation.
Setup
If you only want to find the APIs you need and understand what they are for, you can run this section without reading it.
```
! pip install -q tensorflow
! pip install -q tensorflow-model-optimization
```

```python
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
import tf_keras as keras
import tempfile

input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
# Random integer class labels in [0, 20) for the placeholder data.
y_train = keras.utils.to_categorical(np.random.randint(0, 20, size=1), num_classes=20)

def setup_model():
    model = keras.Sequential([
        keras.layers.Dense(20, input_shape=input_shape),
        keras.layers.Flatten()
    ])
    return model

def setup_pretrained_weights():
    model = setup_model()
    model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy']
    )
    model.fit(x_train, y_train)
    _, pretrained_weights = tempfile.mkstemp('.tf')
    model.save_weights(pretrained_weights)
    return pretrained_weights

def setup_pretrained_model():
    model = setup_model()
    pretrained_weights = setup_pretrained_weights()
    model.load_weights(pretrained_weights)
    return model

setup_model()
pretrained_weights = setup_pretrained_weights()
```
Define quantization aware model
If you define models in the following ways, there are supported paths to deployment on the backends listed in the overview page. By default, 8-bit quantization is used.
Quantize whole model
Your use case:
- Subclassed models are not supported.
Tips for better model accuracy:
- Try "Quantize some layers" to skip quantizing the layers that reduce accuracy the most.
- It's generally better to finetune with quantization aware training as opposed to training from scratch.
To make the whole model aware of quantization, apply tfmot.quantization.keras.quantize_model to the model.
```python
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
quant_aware_model.summary()
```
```
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                          Output Shape        Param #
 quantize_layer (QuantizeLayer)        (None, 20)          3
 quant_dense_2 (QuantizeWrapperV2)     (None, 20)          425
 quant_flatten_2 (QuantizeWrapperV2)   (None, 20)          1
Total params: 429 (1.68 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 9 (36.00 Byte)
```
Quantize some layers
Quantizing a model can have a negative effect on accuracy. You can selectively quantize layers of a model to explore the trade-off between accuracy, speed, and model size.
Your use case:
- To deploy to a backend that only works well with fully quantized models (e.g. EdgeTPU v1, most DSPs), try "Quantize whole model".
Tips for better model accuracy:
- It's generally better to finetune with quantization aware training as opposed to training from scratch.
- Try quantizing the later layers instead of the first layers.
- Avoid quantizing critical layers (e.g. attention mechanism).
In the example below, quantize only the Dense layers.
```python
# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

# Helper function uses `quantize_annotate_layer` to annotate that only the
# Dense layers should be quantized.
def apply_quantization_to_dense(layer):
    if isinstance(layer, keras.layers.Dense):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer

# Use `keras.models.clone_model` to apply `apply_quantization_to_dense`
# to the layers of the model.
annotated_model = keras.models.clone_model(
    base_model,
    clone_function=apply_quantization_to_dense,
)

# Now that the Dense layers are annotated,
# `quantize_apply` actually makes the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
quant_aware_model.summary()
```
```
Model: "sequential_3"
_________________________________________________________________
 Layer (type)                          Output Shape        Param #
 quantize_layer_1 (QuantizeLayer)      (None, 20)          3
 quant_dense_3 (QuantizeWrapperV2)     (None, 20)          425
 flatten_3 (Flatten)                   (None, 20)          0
Total params: 428 (1.67 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 8 (32.00 Byte)
```
While this example used the layer type to decide what to quantize, the easiest way to quantize a particular layer is to set its name property and look for that name in the clone_function.
```python
print(base_model.layers[0].name)
```

```
dense_3
```
More readable but potentially lower model accuracy
This approach is not compatible with fine-tuning with quantization aware training, which is why it may be less accurate than the examples above.
Functional example
```python
# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
i = keras.Input(shape=(20,))
x = tfmot.quantization.keras.quantize_annotate_layer(keras.layers.Dense(10))(i)
o = keras.layers.Flatten()(x)
annotated_model = keras.Model(inputs=i, outputs=o)

# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

# For deployment purposes, the tool adds `QuantizeLayer` after `InputLayer` so that the
# quantized model can take in float inputs instead of only uint8.
quant_aware_model.summary()
```
```
Model: "model"
_________________________________________________________________
 Layer (type)                          Output Shape        Param #
 input_1 (InputLayer)                  [(None, 20)]        0
 quantize_layer_2 (QuantizeLayer)      (None, 20)          3
 quant_dense_4 (QuantizeWrapperV2)     (None, 10)          215
 flatten_4 (Flatten)                   (None, 10)          0
Total params: 218 (872.00 Byte)
Trainable params: 210 (840.00 Byte)
Non-trainable params: 8 (32.00 Byte)
```
Sequential example
```python
# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
annotated_model = keras.Sequential([
    tfmot.quantization.keras.quantize_annotate_layer(keras.layers.Dense(20, input_shape=input_shape)),
    keras.layers.Flatten()
])

# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

quant_aware_model.summary()
```
```
Model: "sequential_4"
_________________________________________________________________
 Layer (type)                          Output Shape        Param #
 quantize_layer_3 (QuantizeLayer)      (None, 20)          3
 quant_dense_5 (QuantizeWrapperV2)     (None, 20)          425
 flatten_5 (Flatten)                   (None, 20)          0
Total params: 428 (1.67 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 8 (32.00 Byte)
```
Checkpoint and deserialize
Your use case: this code is only needed for the HDF5 model format (not HDF5 weights or other formats).
```python
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)

# Save or checkpoint the model.
_, keras_model_file = tempfile.mkstemp('.h5')
quant_aware_model.save(keras_model_file)

# `quantize_scope` is needed for deserializing HDF5 models.
with tfmot.quantization.keras.quantize_scope():
    loaded_model = keras.models.load_model(keras_model_file)

loaded_model.summary()
```
```
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                          Output Shape        Param #
 quantize_layer_4 (QuantizeLayer)      (None, 20)          3
 quant_dense_6 (QuantizeWrapperV2)     (None, 20)          425
 quant_flatten_6 (QuantizeWrapperV2)   (None, 20)          1
Total params: 429 (1.68 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 9 (36.00 Byte)
```
Create and deploy quantized model
In general, refer to the documentation for the deployment backend that you will use.
This is an example for the TFLite backend.
```python
base_model = setup_pretrained_model()
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)

# Typically you train the model here.

converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()
```
Experiment with quantization
Your use case: using the following APIs means that there is no supported path to deployment. For instance, TFLite conversion and kernel implementations only support 8-bit quantization. The features are also experimental and not subject to backward compatibility.
- tfmot.quantization.keras.QuantizeConfig
- tfmot.quantization.keras.quantizers.Quantizer
- tfmot.quantization.keras.quantizers.LastValueQuantizer
- tfmot.quantization.keras.quantizers.MovingAverageQuantizer
Setup: DefaultDenseQuantizeConfig
Experimenting requires using tfmot.quantization.keras.QuantizeConfig, which describes how to quantize the weights, activations, and outputs of a layer.
Below is an example that defines the same QuantizeConfig used for the Dense layer in the API defaults.
During forward propagation in this example, the LastValueQuantizer returned in get_weights_and_quantizers is called with layer.kernel as the input, producing an output. The output replaces layer.kernel in the original forward propagation of the Dense layer, via the logic defined in set_quantize_weights. The same idea applies to the activations and outputs.
```python
LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer

class DefaultDenseQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    # Configure how to quantize weights.
    def get_weights_and_quantizers(self, layer):
        return [(layer.kernel, LastValueQuantizer(num_bits=8, symmetric=True, narrow_range=False, per_axis=False))]

    # Configure how to quantize activations.
    def get_activations_and_quantizers(self, layer):
        return [(layer.activation, MovingAverageQuantizer(num_bits=8, symmetric=False, narrow_range=False, per_axis=False))]

    def set_quantize_weights(self, layer, quantize_weights):
        # Add this line for each item returned in `get_weights_and_quantizers`,
        # in the same order.
        layer.kernel = quantize_weights[0]

    def set_quantize_activations(self, layer, quantize_activations):
        # Add this line for each item returned in `get_activations_and_quantizers`,
        # in the same order.
        layer.activation = quantize_activations[0]

    # Configure how to quantize outputs (may be equivalent to activations).
    def get_output_quantizers(self, layer):
        return []

    def get_config(self):
        return {}
```
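To make the forward-propagation description above concrete, here is a minimal plain-TensorFlow sketch that swaps a fake-quantized kernel into a Dense-style computation y = xW + b. It uses tf.quantization.fake_quant_with_min_max_vars as a stand-in for LastValueQuantizer and assumes the symmetric range is taken from the kernel's current largest absolute value; this is a simplification, and the real quantizer's range handling may differ in detail.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
kernel = tf.constant(rng.standard_normal((20, 20)), dtype=tf.float32)  # stands in for layer.kernel
bias = tf.zeros([20])
x = tf.constant(rng.standard_normal((1, 20)), dtype=tf.float32)

# Assumed symmetric range from the kernel's current values (simplified LastValueQuantizer).
max_abs = tf.reduce_max(tf.abs(kernel))
quantized_kernel = tf.quantization.fake_quant_with_min_max_vars(
    kernel, min=-max_abs, max=max_abs, num_bits=8, narrow_range=False)

# The fake-quantized kernel replaces the float kernel in the Dense computation.
float_out = tf.matmul(x, kernel) + bias
quant_out = tf.matmul(x, quantized_kernel) + bias

levels = np.unique(quantized_kernel.numpy()).size  # at most 2**8 distinct values
max_diff = float(tf.reduce_max(tf.abs(float_out - quant_out)))
```

The quantized kernel keeps the float dtype and shape, so downstream code is unaware of the substitution; only the set of representable values shrinks.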
Quantize custom Keras layer
This example uses the DefaultDenseQuantizeConfig to quantize the CustomLayer.
Applying the configuration is the same across the "Experiment with quantization" use cases.
- Apply tfmot.quantization.keras.quantize_annotate_layer to the CustomLayer and pass in the QuantizeConfig.
- Use tfmot.quantization.keras.quantize_annotate_model to continue to quantize the rest of the model with the API defaults.
```python
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class CustomLayer(keras.layers.Dense):
    pass

model = quantize_annotate_model(keras.Sequential([
    quantize_annotate_layer(CustomLayer(20, input_shape=(20,)), DefaultDenseQuantizeConfig()),
    keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`
# as well as the custom Keras layer.
with quantize_scope(
    {'DefaultDenseQuantizeConfig': DefaultDenseQuantizeConfig,
     'CustomLayer': CustomLayer}):
    # Use `quantize_apply` to actually make the model quantization aware.
    quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
```
```
Model: "sequential_8"
_________________________________________________________________
 Layer (type)                              Output Shape        Param #
 quantize_layer_6 (QuantizeLayer)          (None, 20)          3
 quant_custom_layer (QuantizeWrapperV2)    (None, 20)          425
 quant_flatten_9 (QuantizeWrapperV2)       (None, 20)          1
Total params: 429 (1.68 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 9 (36.00 Byte)
```
Modify quantization parameters
Common mistake: quantizing the bias to fewer than 32 bits usually harms model accuracy too much.
This example modifies the Dense layer to use 4 bits for its weights instead of the default 8 bits. The rest of the model continues to use API defaults.
```python
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    # Configure weights to quantize with 4 bits instead of 8 bits.
    def get_weights_and_quantizers(self, layer):
        return [(layer.kernel, LastValueQuantizer(num_bits=4, symmetric=True, narrow_range=False, per_axis=False))]
```
Applying the configuration is the same across the "Experiment with quantization" use cases.
- Apply tfmot.quantization.keras.quantize_annotate_layer to the Dense layer and pass in the QuantizeConfig.
- Use tfmot.quantization.keras.quantize_annotate_model to continue to quantize the rest of the model with the API defaults.
```python
model = quantize_annotate_model(keras.Sequential([
    # Pass in modified `QuantizeConfig` to modify this Dense layer.
    quantize_annotate_layer(keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
    keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
    {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
    # Use `quantize_apply` to actually make the model quantization aware.
    quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
```
```
Model: "sequential_9"
_________________________________________________________________
 Layer (type)                          Output Shape        Param #
 quantize_layer_7 (QuantizeLayer)      (None, 20)          3
 quant_dense_9 (QuantizeWrapperV2)     (None, 20)          425
 quant_flatten_10 (QuantizeWrapperV2)  (None, 20)          1
Total params: 429 (1.68 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 9 (36.00 Byte)
```
Modify parts of layer to quantize
This example modifies the Dense layer to skip quantizing the activation. The rest of the model continues to use API defaults.
```python
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    def get_activations_and_quantizers(self, layer):
        # Skip quantizing activations.
        return []

    def set_quantize_activations(self, layer, quantize_activations):
        # Empty since `get_activations_and_quantizers` returns
        # an empty list.
        return
```
Applying the configuration is the same across the "Experiment with quantization" use cases.
- Apply tfmot.quantization.keras.quantize_annotate_layer to the Dense layer and pass in the QuantizeConfig.
- Use tfmot.quantization.keras.quantize_annotate_model to continue to quantize the rest of the model with the API defaults.
```python
model = quantize_annotate_model(keras.Sequential([
    # Pass in modified `QuantizeConfig` to modify this Dense layer.
    quantize_annotate_layer(keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
    keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
    {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
    # Use `quantize_apply` to actually make the model quantization aware.
    quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
```
```
Model: "sequential_10"
_________________________________________________________________
 Layer (type)                          Output Shape        Param #
 quantize_layer_8 (QuantizeLayer)      (None, 20)          3
 quant_dense_10 (QuantizeWrapperV2)    (None, 20)          423
 quant_flatten_11 (QuantizeWrapperV2)  (None, 20)          1
Total params: 427 (1.67 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 7 (28.00 Byte)
```
Use custom quantization algorithm
The tfmot.quantization.keras.quantizers.Quantizer class is a callable that can apply any algorithm to its inputs.
In this example, the inputs are the weights, and we apply the math in the FixedRangeQuantizer __call__ function to the weights. Instead of the original weight values, the output of the FixedRangeQuantizer is now passed to whatever would have used the weights.
```python
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class FixedRangeQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
    """Quantizer which forces outputs to be between -1 and 1."""

    def build(self, tensor_shape, name, layer):
        # Not needed. No new TensorFlow variables needed.
        return {}

    def __call__(self, inputs, training, weights, **kwargs):
        return keras.backend.clip(inputs, -1.0, 1.0)

    def get_config(self):
        # Not needed. No __init__ parameters to serialize.
        return {}


class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    # Configure weights to use the custom FixedRangeQuantizer.
    def get_weights_and_quantizers(self, layer):
        # Use custom algorithm defined in `FixedRangeQuantizer` instead of default Quantizer.
        return [(layer.kernel, FixedRangeQuantizer())]
```
Applying the configuration is the same across the "Experiment with quantization" use cases.
- Apply tfmot.quantization.keras.quantize_annotate_layer to the Dense layer and pass in the QuantizeConfig.
- Use tfmot.quantization.keras.quantize_annotate_model to continue to quantize the rest of the model with the API defaults.
```python
model = quantize_annotate_model(keras.Sequential([
    # Pass in modified `QuantizeConfig` to modify this `Dense` layer.
    quantize_annotate_layer(keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
    keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
    {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
    # Use `quantize_apply` to actually make the model quantization aware.
    quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
```
```
Model: "sequential_11"
_________________________________________________________________
 Layer (type)                          Output Shape        Param #
 quantize_layer_9 (QuantizeLayer)      (None, 20)          3
 quant_dense_11 (QuantizeWrapperV2)    (None, 20)          423
 quant_flatten_12 (QuantizeWrapperV2)  (None, 20)          1
Total params: 427 (1.67 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 7 (28.00 Byte)
```