Cluster preserving quantization aware training (CQAT) Keras example

Overview

This is an end to end example showing the usage of the cluster preserving quantization aware training (CQAT) API, part of the TensorFlow Model Optimization Toolkit's collaborative optimization pipeline.

Other pages

For an introduction to the pipeline and other available techniques, see the collaborative optimization overview page.

Contents

In the tutorial, you will:

  1. Train a Keras model for the MNIST dataset from scratch.
  2. Fine-tune the model with clustering and see the accuracy.
  3. Apply QAT and observe the loss of clusters.
  4. Apply CQAT and observe that the clustering applied earlier has been preserved.
  5. Generate a TFLite model and observe the effects of applying CQAT on it.
  6. Compare the achieved CQAT model accuracy with a model quantized using post-training quantization.

Setup

You can run this Jupyter Notebook in your local virtualenv or in Colab. For details on setting up dependencies, refer to the installation guide.

 pip install -q tensorflow-model-optimization
import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os


Train a Keras model for MNIST without clustering

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)

2025-05-01 11:19:17.620478: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Epoch 1/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.3373 - accuracy: 0.9049 - val_loss: 0.1517 - val_accuracy: 0.9607
Epoch 2/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.1411 - accuracy: 0.9597 - val_loss: 0.0945 - val_accuracy: 0.9753
Epoch 3/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0963 - accuracy: 0.9729 - val_loss: 0.0777 - val_accuracy: 0.9798
Epoch 4/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0758 - accuracy: 0.9780 - val_loss: 0.0691 - val_accuracy: 0.9820
Epoch 5/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0637 - accuracy: 0.9809 - val_loss: 0.0645 - val_accuracy: 0.9825
Epoch 6/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0560 - accuracy: 0.9834 - val_loss: 0.0613 - val_accuracy: 0.9838
Epoch 7/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0493 - accuracy: 0.9857 - val_loss: 0.0617 - val_accuracy: 0.9840
Epoch 8/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0445 - accuracy: 0.9864 - val_loss: 0.0563 - val_accuracy: 0.9843
Epoch 9/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0400 - accuracy: 0.9876 - val_loss: 0.0662 - val_accuracy: 0.9847
Epoch 10/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0361 - accuracy: 0.9895 - val_loss: 0.0567 - val_accuracy: 0.9837
<tf_keras.src.callbacks.History at 0x7ff8c0b6e940>

Evaluate the baseline model and save it for later usage

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)

Baseline test accuracy: 0.982699990272522
Saving model to: /tmpfs/tmp/tmpkmnsttu9.h5
/tmpfs/tmp/ipykernel_13622/3680774635.py:8: UserWarning: You are saving your model as an HDF5 file via model.save(). This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. model.save('my_model.keras').
  keras.models.save_model(model, keras_file, include_optimizer=False)

Cluster and fine-tune the model with 8 clusters

Apply the cluster_weights() API to cluster the whole pre-trained model, to demonstrate how clustering reduces the size of the zipped model while maintaining accuracy. For guidance on using the API to get the best compression rate while keeping your target accuracy, refer to the clustering comprehensive guide.
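
As a rough back-of-the-envelope sketch of why a kernel with few distinct values is cheap to store, the snippet below compares plain float32 storage with an idealized codebook layout (one float32 per centroid plus a fixed-width index per weight). The layout and the helper names are assumptions made for illustration; this is not how TensorFlow or TFLite serializes a model, and in this tutorial the saving is realized through zip compression instead.

```python
import math

def dense_bytes(num_weights, bytes_per_weight=4):
    """Bytes needed to store every weight as its own float32."""
    return num_weights * bytes_per_weight

def codebook_bytes(num_weights, number_of_clusters, bytes_per_centroid=4):
    """Idealized size: a centroid table plus one fixed-width index per weight."""
    bits_per_index = max(1, math.ceil(math.log2(number_of_clusters)))
    index_bytes = math.ceil(num_weights * bits_per_index / 8)
    return number_of_clusters * bytes_per_centroid + index_bytes

# A 2028 x 10 kernel (the shape of this tutorial's dense layer), 8 centroids.
num_weights = 2028 * 10
print("float32 storage:", dense_bytes(num_weights), "bytes")          # 81120
print("codebook storage:", codebook_bytes(num_weights, 8), "bytes")   # 7637
```

With 8 clusters each index needs only 3 bits, which is where most of the idealized saving comes from.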

Define the model and apply the clustering API

The model needs to be pre-trained before using the clustering API.

import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
  'cluster_per_channel': True,
}

clustered_model = cluster_weights(model, **clustering_params)

# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)

clustered_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

clustered_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                             Output Shape          Param #
=================================================================
 cluster_reshape (ClusterWeights)         (None, 28, 28, 1)     0
 cluster_conv2d (ClusterWeights)          (None, 26, 26, 12)    324
 cluster_max_pooling2d (ClusterWeights)   (None, 13, 13, 12)    0
 cluster_flatten (ClusterWeights)         (None, 2028)          0
 cluster_dense (ClusterWeights)           (None, 10)            40578
=================================================================
Total params: 40902 (239.41 KB)
Trainable params: 20514 (80.13 KB)
Non-trainable params: 20388 (159.28 KB)
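
The parameter counts in the summary can be reproduced with a little arithmetic. The sketch below assumes that each ClusterWeights wrapper keeps the original kernel and bias, adds one trainable centroid per cluster (per output channel for the convolution, because cluster_per_channel is True), and adds one non-trainable index per kernel weight. That breakdown is an assumption chosen because it matches the totals above; it is not taken from the library's documentation.

```python
NUMBER_OF_CLUSTERS = 8
FILTERS = 12

# Output shapes: 3x3 convolution without padding, then 2x2 max pooling.
conv_out = 28 - 3 + 1                       # 26
pool_out = conv_out // 2                    # 13
flat = pool_out * pool_out * FILTERS        # 2028

# Convolution layer (assumed breakdown of the wrapper's variables).
conv_kernel = 3 * 3 * 1 * FILTERS           # 108 weights
conv_bias = FILTERS
conv_centroids = NUMBER_OF_CLUSTERS * FILTERS   # one centroid set per channel
conv_indices = conv_kernel                  # one index per kernel weight
conv_total = conv_kernel + conv_bias + conv_centroids + conv_indices

# Dense layer (per-channel clustering does not apply here).
dense_kernel = flat * 10                    # 20280 weights
dense_bias = 10
dense_centroids = NUMBER_OF_CLUSTERS
dense_indices = dense_kernel
dense_total = dense_kernel + dense_bias + dense_centroids + dense_indices

trainable = (conv_kernel + conv_bias + conv_centroids
             + dense_kernel + dense_bias + dense_centroids)
non_trainable = conv_indices + dense_indices

print(conv_total, dense_total)                            # 324 40578
print(trainable, non_trainable, trainable + non_trainable)  # 20514 20388 40902
```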


Fine-tune the model and evaluate the accuracy against baseline

Fine-tune the model with clustering for 3 epochs.

# Fine-tune model
clustered_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1)

Epoch 1/3
1688/1688 [==============================] - 9s 5ms/step - loss: 0.0331 - accuracy: 0.9904 - val_loss: 0.0580 - val_accuracy: 0.9828
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0296 - accuracy: 0.9917 - val_loss: 0.0571 - val_accuracy: 0.9842
Epoch 3/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0294 - accuracy: 0.9920 - val_loss: 0.0579 - val_accuracy: 0.9840
<tf_keras.src.callbacks.History at 0x7ff7c98ad280>

Define a helper function to calculate and print the number of clusters in each kernel of the model.

def print_model_weight_clusters(model):

    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

Check that the model kernels were correctly clustered. We need to strip the clustering wrapper first.

stripped_clustered_model = tfmot.clustering.keras.strip_clustering(clustered_model)

print_model_weight_clusters(stripped_clustered_model)

conv2d/kernel:0: 96 clusters
dense/kernel:0: 8 clusters
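
The convolution kernel reports 96 distinct values rather than 8 because clustering_params sets cluster_per_channel to True, so each of the 12 output channels gets its own set of 8 centroids. The sketch below computes that upper bound. max_unique_kernel_values is a hypothetical helper, and treating the last kernel axis as the channel axis is an assumption based on the Keras kernel layout.

```python
def max_unique_kernel_values(kernel_shape, number_of_clusters, per_channel):
    """Upper bound on the number of distinct values in a clustered kernel."""
    total_weights = 1
    for dim in kernel_shape:
        total_weights *= dim
    if per_channel:
        # Assumption: the last axis indexes output channels.
        channels = kernel_shape[-1]
        weights_per_channel = total_weights // channels
        return channels * min(number_of_clusters, weights_per_channel)
    return min(number_of_clusters, total_weights)

# Conv2D kernel of this model: 3x3, 1 input channel, 12 filters.
print(max_unique_kernel_values((3, 3, 1, 12), 8, per_channel=True))   # 96
# Dense kernel: clustered as a whole, as the reported count of 8 shows.
print(max_unique_kernel_values((2028, 10), 8, per_channel=False))     # 8
```

Each channel of the convolution holds 9 weights, so 8 centroids per channel leave very little sharing there; most of the compression in this model comes from the dense layer.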

For this example, there is minimal loss in test accuracy after clustering, compared to the baseline.

_, clustered_model_accuracy = clustered_model.evaluate(
  test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)

Baseline test accuracy: 0.982699990272522
Clustered test accuracy: 0.983299970626831

Apply QAT and CQAT and check effect on model clusters in both cases

Next, apply both QAT and cluster preserving QAT (CQAT) to the clustered model and observe that only CQAT preserves the weight clusters. Note that the clustering wrappers were stripped from the model with tfmot.clustering.keras.strip_clustering before applying the CQAT API.

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# CQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
cqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme())

cqat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train cqat model:')
cqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

Train qat model:
422/422 [==============================] - 4s 7ms/step - loss: 0.0322 - accuracy: 0.9907 - val_loss: 0.0545 - val_accuracy: 0.9867
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
Train cqat model:
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using model.compile(), did you forget to provide a loss argument?
422/422 [==============================] - 5s 7ms/step - loss: 0.0291 - accuracy: 0.9918 - val_loss: 0.0591 - val_accuracy: 0.9842
<tf_keras.src.callbacks.History at 0x7ff7aee04e50>

print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("CQAT Model clusters:")
print_model_weight_clusters(cqat_model)

QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 108 clusters
quant_dense/dense/kernel:0: 19860 clusters
CQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 96 clusters
quant_dense/dense/kernel:0: 8 clusters
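
Why does ordinary fine-tuning scatter the clusters while a cluster preserving scheme keeps them? The toy sketch below is a conceptual illustration only, not the tfmot implementation: it contrasts updating every weight independently with updating shared centroids and re-deriving the weights from them. All values, gradients, and names are made up for the example.

```python
def unique_count(values):
    """Number of distinct values, ignoring floating-point dust."""
    return len({round(v, 12) for v in values})

# A tiny clustered "kernel": every weight is a lookup into a centroid table.
centroids = [-0.5, 0.0, 0.5]
indices = [0, 1, 2, 0, 1, 2, 0, 1]
weights = [centroids[i] for i in indices]

# Made-up gradients, deliberately different for every weight.
grads = [0.01 * (n + 1) for n in range(len(weights))]
learning_rate = 0.1

# Plain update: each weight moves on its own, so shared values drift apart.
plain = [w - learning_rate * g for w, g in zip(weights, grads)]

# Cluster preserving update (conceptual): move each centroid by the mean
# gradient of its members, then rebuild the weights from the table.
new_centroids = list(centroids)
for k in range(len(centroids)):
    member_grads = [g for g, i in zip(grads, indices) if i == k]
    new_centroids[k] -= learning_rate * sum(member_grads) / len(member_grads)
preserved = [new_centroids[i] for i in indices]

print("before:", unique_count(weights))                 # 3
print("plain update:", unique_count(plain))             # 8
print("centroid update:", unique_count(preserved))      # 3
```

The same pattern shows up in the counts above: after plain QAT almost every dense weight is distinct, while the CQAT model still has 8 values.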

See compression benefits of CQAT model

Define a helper function that zips a model file and returns the compressed size.

def get_gzipped_model_size(file):
  # Returns the size of the zipped model file in kilobytes.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

Note that this is a small model. Applying clustering and CQAT to a bigger production model would yield a more significant compression.

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# CQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(cqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
cqat_tflite_model = converter.convert()
cqat_model_file = 'cqat_model.tflite'
# Save the model.
with open(cqat_model_file, 'wb') as f:
    f.write(cqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("CQAT model size: ", get_gzipped_model_size(cqat_model_file), ' KB')

INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpz12ayq1l/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:854: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1746098462.294508 13622 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1746098462.294551 13622 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
I0000 00:00:1746098462.314870 13622 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp4jfflb8a/assets
QAT model size: 17.434 KB
CQAT model size: 10.807 KB
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:854: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
W0000 00:00:1746098464.897716 13622 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1746098464.897742 13622 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
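
The size gap between the two files comes from how well the weight bytes compress: a kernel with a handful of distinct values is far more repetitive than one with many. The sketch below illustrates this with synthetic byte strings and zlib, which implements the same DEFLATE algorithm selected by zipfile.ZIP_DEFLATED. The data is random and only stands in for int8 weights, and the eight centroid values are hypothetical, so the exact numbers say nothing about the real models.

```python
import random
import zlib

random.seed(0)
NUM_WEIGHTS = 20000

# Hypothetical int8 centroids; "& 0xFF" maps signed values to raw bytes.
clustered_values = [-96, -64, -32, -8, 8, 32, 64, 96]
clustered = bytes(
    random.choice(clustered_values) & 0xFF for _ in range(NUM_WEIGHTS))

# Stand-in for an unclustered kernel: any of the 256 byte values can occur.
unclustered = bytes(random.randrange(256) for _ in range(NUM_WEIGHTS))

clustered_size = len(zlib.compress(clustered, 9))
unclustered_size = len(zlib.compress(unclustered, 9))

print("clustered bytes after DEFLATE:  ", clustered_size)
print("unclustered bytes after DEFLATE:", unclustered_size)
```

Eight equally likely values carry about 3 bits of information per byte, so the clustered buffer shrinks substantially, while the unclustered buffer is close to incompressible.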

See the persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Evaluate the clustered and quantized model and check that the accuracy measured in TensorFlow persists in the TFLite backend.

interpreter = tf.lite.Interpreter(cqat_model_file)
interpreter.allocate_tensors()

cqat_test_accuracy = eval_model(interpreter)

print('Clustered and quantized TFLite test_accuracy:', cqat_test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)

/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/interpreter.py:457: UserWarning: Warning: tf.lite.Interpreter is deprecated and is scheduled for deletion in TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package. See the migration guide for details.
  warnings.warn(_INTERPRETER_DELETION_WARNING)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.

Clustered and quantized TFLite test_accuracy: 0.9821
Clustered TF test accuracy: 0.983299970626831

Apply post-training quantization and compare to CQAT model

Next, apply post-training quantization (no fine-tuning) to the clustered model and check its accuracy against the CQAT model. The comparison is meant to show why you might use CQAT to improve the quantized model's accuracy. The difference may not be visible here, because the MNIST model is quite small and overparametrized.

First, define a generator for the calibration dataset from the first 1000 training images.

def mnist_representative_data_gen():
  for image in train_images[:1000]:  
    image = np.expand_dims(image, axis=0).astype(np.float32)
    yield [image]
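
The calibration generator supplies sample inputs so that the converter can estimate the value ranges it needs for quantization. As a minimal sketch of how an observed range can be turned into int8 parameters, the snippet below uses a generic affine (scale and zero point) scheme. The function names and sample values are hypothetical, and the converter's actual calibration logic may differ.

```python
def affine_params(min_val, max_val, qmin=-128, qmax=127):
    """Derive scale and zero point from an observed float range."""
    # Widen the range to include 0.0 so that zero is exactly representable.
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    if max_val == min_val:
        raise ValueError("calibration range is empty")
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = int(round(qmin - min_val / scale))
    return scale, zero_point

def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Map a float to the nearest representable int8 level (with clamping)."""
    q = int(round(x / scale)) + zero_point
    return max(qmin, min(qmax, q))

def dequantize(q, scale, zero_point):
    """Map an int8 level back to its float value."""
    return (q - zero_point) * scale

# Hypothetical calibration samples in [0, 1], like normalized MNIST pixels.
calibration = [0.0, 0.25, 0.9, 1.0]
scale, zero_point = affine_params(min(calibration), max(calibration))

for x in calibration:
    q = quantize(x, scale, zero_point)
    print(x, "->", q, "->", dequantize(q, scale, zero_point))
```

For values inside the calibrated range, the round trip error is at most half a quantization step (scale / 2), which is the loss that fine-tuning tries to compensate for.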

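Quantize the model and compare its accuracy to the previously acquired CQAT model. Fine-tuning is generally expected to help the quantized model's accuracy, but in the run recorded below the two results differ by only about 0.1 percentage point (about 11 of the 10,000 test images), with the post-training model slightly ahead. This is consistent with the earlier caveat that the difference may not be visible on a small MNIST model.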

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
    f.write(post_training_tflite_model)

# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()

post_training_test_accuracy = eval_model(interpreter)

print('CQAT TFLite test_accuracy:', cqat_test_accuracy)
print('Post-training (no fine-tuning) TF test accuracy:', post_training_test_accuracy)

INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpwjzir7ct/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:854: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
W0000 00:00:1746098466.671121 13622 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1746098466.671157 13622 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
fully_quantize: 0, inference_type: 6, input_inference_type: FLOAT32, output_inference_type: FLOAT32
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/interpreter.py:457: UserWarning: Warning: tf.lite.Interpreter is deprecated and is scheduled for deletion in TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package. See the migration guide for details.
  warnings.warn(_INTERPRETER_DELETION_WARNING)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.

CQAT TFLite test_accuracy: 0.9821
Post-training (no fine-tuning) TF test accuracy: 0.9832

Conclusion

In this tutorial, you learned how to create a model, cluster it using the cluster_weights() API, and apply cluster preserving quantization aware training (CQAT) to preserve clusters while using QAT. The final CQAT model was compared to the QAT one to show that the clusters are preserved in the former and lost in the latter. Next, the models were converted to TFLite to show the compression benefits of chaining clustering and CQAT, and the TFLite model was evaluated to confirm that the accuracy persists in the TFLite backend. Finally, the CQAT model was compared to a clustered model quantized with the post-training quantization API. CQAT is intended to recover accuracy lost to normal quantization, although on this small MNIST model the two accuracies were nearly identical.