Sparsity and cluster preserving quantization aware training (PCQAT) Keras example

Overview

This is an end-to-end example showing the usage of the sparsity and cluster preserving quantization aware training (PCQAT) API, part of the TensorFlow Model Optimization Toolkit's collaborative optimization pipeline.

Other pages

For an introduction to the pipeline and other available techniques, see the collaborative optimization overview page.

Contents

In the tutorial, you will:

  1. Train a Keras model for the MNIST dataset from scratch.
  2. Fine-tune the model with pruning, check its accuracy, and observe that the model was successfully pruned.
  3. Apply sparsity preserving clustering on the pruned model and observe that the sparsity applied earlier has been preserved.
  4. Apply QAT and observe the loss of sparsity and clusters.
  5. Apply PCQAT and observe that both sparsity and clustering applied earlier have been preserved.
  6. Generate a TFLite model and observe the effects of applying PCQAT on it.
  7. Compare the sizes of the different models to observe the compression benefits of applying sparsity followed by the collaborative optimization techniques of sparsity preserving clustering and PCQAT.
  8. Compare the accuracy of the fully optimized model with that of the un-optimized baseline model.

Setup

You can run this Jupyter Notebook in your local virtualenv or colab. For details of setting up dependencies, please refer to the installation guide.

 pip install -q tensorflow-model-optimization
import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os


Train a Keras model for MNIST to be pruned and clustered

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

opt = keras.optimizers.Adam(learning_rate=1e-3)

# Train the digit classification model
model.compile(optimizer=opt,
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)

Epoch 1/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.3089 - accuracy: 0.9126 - val_loss: 0.1210 - val_accuracy: 0.9692
Epoch 2/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.1112 - accuracy: 0.9688 - val_loss: 0.0787 - val_accuracy: 0.9798
Epoch 3/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0779 - accuracy: 0.9779 - val_loss: 0.0679 - val_accuracy: 0.9820
Epoch 4/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0637 - accuracy: 0.9818 - val_loss: 0.0599 - val_accuracy: 0.9840
Epoch 5/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0547 - accuracy: 0.9833 - val_loss: 0.0599 - val_accuracy: 0.9843
Epoch 6/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0480 - accuracy: 0.9856 - val_loss: 0.0562 - val_accuracy: 0.9858
Epoch 7/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0430 - accuracy: 0.9873 - val_loss: 0.0611 - val_accuracy: 0.9840
Epoch 8/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0381 - accuracy: 0.9888 - val_loss: 0.0605 - val_accuracy: 0.9845
Epoch 9/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0349 - accuracy: 0.9892 - val_loss: 0.0531 - val_accuracy: 0.9857
Epoch 10/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0307 - accuracy: 0.9909 - val_loss: 0.0579 - val_accuracy: 0.9848
<tf_keras.src.callbacks.History at 0x7f0529d3da90>

Evaluate the baseline model and save it for later usage

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)

Baseline test accuracy: 0.983299970626831
Saving model to:  /tmpfs/tmp/tmpm3g0coct.h5
/tmpfs/tmp/ipykernel_16919/3680774635.py:8: UserWarning: You are saving your model as an HDF5 file via model.save(). This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. model.save('my_model.keras').
  keras.models.save_model(model, keras_file, include_optimizer=False)

Prune and fine-tune the model to 50% sparsity

Apply the prune_low_magnitude() API to obtain the pruned model that will be clustered in the next step. Refer to the pruning comprehensive guide for more information on the pruning API.
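To make the idea behind magnitude pruning concrete, here is a minimal pure-Python sketch, assuming a one-shot scheme that zeroes the smallest-magnitude fraction of a flat list of weights. The helper `prune_to_sparsity` is hypothetical and is not part of tfmot; `prune_low_magnitude()` with `ConstantSparsity(0.5, ...)` works on TensorFlow variables through a mask that is updated during fine-tuning.

```python
def prune_to_sparsity(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest magnitude.

    Hypothetical helper for illustration only; returns the pruned weights
    and the binary mask (1 = kept, 0 = pruned).
    """
    num_to_prune = int(round(len(weights) * sparsity))
    # Indices ordered from smallest to largest absolute value
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    mask = [1] * len(weights)
    for i in order[:num_to_prune]:
        mask[i] = 0
    pruned = [w if m else 0.0 for w, m in zip(weights, mask)]
    return pruned, mask


weights = [0.8, -0.05, 0.3, -0.6, 0.01, 0.2]
pruned, mask = prune_to_sparsity(weights, 0.5)
print(pruned)  # [0.8, 0.0, 0.3, -0.6, 0.0, 0.0]
print(mask)    # [1, 0, 1, 1, 0, 0]
```

At 50% sparsity, a kernel of 108 weights (the 3x3x1x12 convolution in this model) ends up with 54 zeros, which is the 54/108 figure reported for conv2d/kernel:0 after fine-tuning.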

Define the model and apply the sparsity API

Note that the pre-trained baseline model is used as the starting point.

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
  }

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep()
]

pruned_model = prune_low_magnitude(model, **pruning_params)

# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)

pruned_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

Fine-tune the model, check sparsity, and evaluate the accuracy against baseline

Fine-tune the model with pruning for 3 epochs.

# Fine-tune model
pruned_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1,
  callbacks=callbacks)

Epoch 1/3
1688/1688 [==============================] - 9s 4ms/step - loss: 0.1033 - accuracy: 0.9655 - val_loss: 0.1024 - val_accuracy: 0.9680
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0706 - accuracy: 0.9774 - val_loss: 0.0840 - val_accuracy: 0.9748
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0572 - accuracy: 0.9825 - val_loss: 0.0770 - val_accuracy: 0.9783
<tf_keras.src.callbacks.History at 0x7f047a35da90>
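The progress bars report 1688 steps per epoch for this run and 422 steps for the later QAT and PCQAT runs. As a quick sanity check of that arithmetic, the sketch below assumes the 60,000 MNIST training images, the 10% validation split, the default `fit()` batch size of 32, and the `batch_size=128` passed to the QAT and PCQAT `fit()` calls.

```python
import math

MNIST_TRAIN_IMAGES = 60_000
VALIDATION_SPLIT = 0.1

# fit() holds out 10% of the training arrays for validation
num_train = round(MNIST_TRAIN_IMAGES * (1 - VALIDATION_SPLIT))  # 54000


def steps_per_epoch(num_examples, batch_size):
    # The last, partial batch still counts as one step
    return math.ceil(num_examples / batch_size)


print(steps_per_epoch(num_train, 32))   # 1688 (default batch size)
print(steps_per_epoch(num_train, 128))  # 422 (QAT and PCQAT runs)
```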

Define helper functions to calculate and print the sparsity and clusters of the model.

def print_model_weights_sparsity(model):
    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            if "kernel" not in weight.name or "centroid" in weight.name:
                continue
            weight_size = weight.numpy().size
            zero_num = np.count_nonzero(weight == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )

def print_model_weight_clusters(model):
    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

Let's strip the pruning wrapper first, then check that the model kernels were correctly pruned.

stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)

conv2d/kernel:0: 50.00% sparsity  (54/108)
dense/kernel:0: 50.00% sparsity  (10140/20280)

Apply sparsity preserving clustering and check its effect on model sparsity

Next, apply sparsity preserving clustering on the pruned model and observe the number of clusters and check that the sparsity is preserved.

import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
    cluster,
)

CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

# Use the experimental cluster_weights(), which accepts the preserve_sparsity
# argument used below.
cluster_weights = cluster.cluster_weights

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
  'preserve_sparsity': True
}

sparsity_clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)

sparsity_clustered_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

print('Train sparsity preserving clustering model:')
sparsity_clustered_model.fit(train_images, train_labels, epochs=3, validation_split=0.1)

Train sparsity preserving clustering model:
Epoch 1/3
1688/1688 [==============================] - 8s 4ms/step - loss: 0.0443 - accuracy: 0.9862 - val_loss: 0.0655 - val_accuracy: 0.9830
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0431 - accuracy: 0.9864 - val_loss: 0.0669 - val_accuracy: 0.9835
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0424 - accuracy: 0.9854 - val_loss: 0.0635 - val_accuracy: 0.9837
<tf_keras.src.callbacks.History at 0x7f047a025940>

Strip the clustering wrapper first, then check that the model is correctly pruned and clustered.

stripped_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)

print("Model sparsity:\n")
print_model_weights_sparsity(stripped_clustered_model)

print("\nModel clusters:\n")
print_model_weight_clusters(stripped_clustered_model)

Model sparsity:

kernel:0: 50.93% sparsity  (55/108)
kernel:0: 59.39% sparsity  (12045/20280)

Model clusters:

conv2d/kernel:0: 8 clusters
dense/kernel:0: 8 clusters
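The clustering step above can be pictured with a simplified pure-Python sketch for a single kernel. It assumes that zero weights are held fixed and count as one of the `num_clusters` values, while the remaining nonzero weights are grouped with plain 1-D k-means. The helper `cluster_preserving_sparsity` is hypothetical; tfmot initializes centroids with `KMEANS_PLUS_PLUS` and then fine-tunes them during training, which this sketch does not attempt.

```python
def cluster_preserving_sparsity(weights, num_clusters, iterations=20):
    """Cluster weights into at most `num_clusters` values, keeping zeros at zero.

    Hypothetical illustration: zero is reserved as one of the clusters and the
    nonzero weights share the remaining num_clusters - 1 centroids (1-D k-means).
    """
    if num_clusters < 2:
        raise ValueError("num_clusters must be at least 2")
    nonzero = [w for w in weights if w != 0.0]
    if not nonzero:
        return list(weights)
    k = num_clusters - 1
    lo, hi = min(nonzero), max(nonzero)
    # Start the centroids at the midpoints of k equal bins over [lo, hi]
    centroids = [lo + (hi - lo) * (j + 0.5) / k for j in range(k)]
    for _ in range(iterations):
        groups = [[] for _ in range(k)]
        for w in nonzero:
            # Assignment step: attach each nonzero weight to its nearest centroid
            nearest = min(range(k), key=lambda j: abs(w - centroids[j]))
            groups[nearest].append(w)
        # Update step: move each centroid to the mean of its group
        centroids = [sum(g) / len(g) if g else c for g, c in zip(groups, centroids)]
    # Replace every nonzero weight with its nearest centroid; zeros stay zero
    return [0.0 if w == 0.0 else min(centroids, key=lambda c: abs(w - c))
            for w in weights]


weights = [0.0, 0.0, 0.0, 0.91, 1.02, 0.98, -0.52, -0.49, 0.0, -0.51]
clustered = cluster_preserving_sparsity(weights, num_clusters=3)
print(clustered.count(0.0), "zeros,", len(set(clustered)), "distinct values")
# 4 zeros, 3 distinct values
```

Because zeros are excluded from the assignment step, the number of zeros in this sketch can never decrease, which is the property the sparsity figures above check for.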

Apply QAT and PCQAT and check the effect on model clusters and sparsity

Next, apply both QAT and PCQAT on the sparse clustered model and observe that PCQAT preserves weight sparsity and clusters in your model. Note that the stripped model is passed to the QAT and PCQAT API.
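Why does plain QAT lose the structure while PCQAT keeps it? Here is a toy pure-Python sketch built on a conceptual model, not tfmot code: PCQAT rebuilds each kernel as `mask * centroids[index]` and trains only the shared centroids, whereas QAT lets every weight move on its own. This model is consistent with the "Gradients do not exist for variables" warning for `conv2d/kernel:0` and `dense/kernel:0` in the PCQAT training log below. The gradients are made-up numbers, and the fake-quantization step of real QAT is omitted.

```python
def free_update(weights, grads, lr):
    """QAT-style step: every weight is updated independently."""
    return [w - lr * g for w, g in zip(weights, grads)]


def shared_update(centroids, indices, mask, grads, lr):
    """Cluster- and sparsity-preserving step (conceptual sketch).

    Each weight is mask[i] * centroids[indices[i]], so a centroid's gradient is
    the sum of the gradients of the kept (mask == 1) weights that share it.
    """
    centroid_grads = [0.0] * len(centroids)
    for idx, keep, g in zip(indices, mask, grads):
        if keep:
            centroid_grads[idx] += g
    new_centroids = [c - lr * cg for c, cg in zip(centroids, centroid_grads)]
    weights = [new_centroids[idx] if keep else 0.0
               for idx, keep in zip(indices, mask)]
    return new_centroids, weights


centroids = [0.5, -0.3]
indices = [0, 0, 1, 1, 0, 1]
mask = [1, 1, 1, 1, 0, 0]
weights = [centroids[i] if m else 0.0 for i, m in zip(indices, mask)]
grads = [0.1, -0.2, 0.3, 0.1, 0.5, -0.4]  # hypothetical gradients

qat_weights = free_update(weights, grads, lr=0.1)
new_centroids, pcqat_weights = shared_update(centroids, indices, mask, grads, lr=0.1)

print(len(set(qat_weights)), qat_weights.count(0.0))      # 6 0
print(len(set(pcqat_weights)), pcqat_weights.count(0.0))  # 3 2
```

After one step, the free update already produces six distinct values and no zeros, while the shared update still has two nonzero values plus the two pruned zeros.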

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# PCQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
pcqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))

pcqat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train pcqat model:')
pcqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

Train qat model:
422/422 [==============================] - 3s 6ms/step - loss: 0.0304 - accuracy: 0.9909 - val_loss: 0.0570 - val_accuracy: 0.9852
Train pcqat model:
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using model.compile(), did you forget to provide a loss argument?
422/422 [==============================] - 4s 6ms/step - loss: 0.0323 - accuracy: 0.9899 - val_loss: 0.0610 - val_accuracy: 0.9833
<tf_keras.src.callbacks.History at 0x7f047a4dba30>

print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("\nQAT Model sparsity:")
print_model_weights_sparsity(qat_model)
print("\nPCQAT Model clusters:")
print_model_weight_clusters(pcqat_model)
print("\nPCQAT Model sparsity:")
print_model_weights_sparsity(pcqat_model)

QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 100 clusters
quant_dense/dense/kernel:0: 18225 clusters

QAT Model sparsity:
conv2d/kernel:0: 8.33% sparsity  (9/108)
dense/kernel:0: 7.24% sparsity  (1469/20280)

PCQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 8 clusters
quant_dense/dense/kernel:0: 8 clusters

PCQAT Model sparsity:
conv2d/kernel:0: 50.93% sparsity  (55/108)
dense/kernel:0: 59.40% sparsity  (12046/20280)
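The PCQAT kernels above still hold 8 distinct values and their zeros, which is what allows these properties to carry over into the 8-bit TFLite model. Here is a simplified sketch, assuming symmetric per-tensor int8 quantization with scale = max|w| / 127 and a zero point of 0. Under that scheme zero always maps to code 0 and equal weights always map to equal codes, so the number of distinct codes can shrink but never grow. TFLite's actual weight quantization (for example, per-channel scales for convolution kernels) may differ in detail, and `quantize_symmetric_int8` is a hypothetical helper.

```python
def quantize_symmetric_int8(weights):
    """Symmetric per-tensor int8 quantization (simplified illustration).

    Returns integer codes in [-127, 127] and the scale, so that
    weights[i] is approximately codes[i] * scale.
    """
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # Round to the nearest code and clamp to the symmetric int8 range
    codes = [max(-127, min(127, int(round(w / scale)))) for w in weights]
    return codes, scale


clustered_kernel = [0.51, 0.51, -0.34, -0.34, 0.0, 0.0]
codes, scale = quantize_symmetric_int8(clustered_kernel)
print(codes)  # [127, 127, -85, -85, 0, 0]
```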

See the compression benefits of the PCQAT model

Define a helper function to get the size of the zipped model file.

def get_gzipped_model_size(file):
  # Returns the size, in kilobytes, of the model file compressed into a ZIP archive with DEFLATE.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

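Convert both the QAT and PCQAT models to TFLite with default optimizations and compare their compressed sizes. Observe that applying sparsity, clustering and PCQAT to a model yields significant compression benefits over QAT alone.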

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# PCQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(pcqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
pcqat_tflite_model = converter.convert()
pcqat_model_file = 'pcqat_model.tflite'
# Save the model.
with open(pcqat_model_file, 'wb') as f:
    f.write(pcqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("PCQAT model size: ", get_gzipped_model_size(pcqat_model_file), ' KB')

QAT model size:  13.81  KB
PCQAT model size:  7.541  KB
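Why is the compressed PCQAT file roughly half the size of the QAT one? `get_gzipped_model_size` uses ZIP with DEFLATE, and DEFLATE shrinks data with many repeated byte values far better than high-entropy data. The sketch below uses synthetic int8 weight bytes, not the tutorial's models. It assumes that `zlib`'s DEFLATE output is representative of what `zipfile.ZIP_DEFLATED` achieves; a real `.tflite` file also contains non-weight data. The cluster codes are hypothetical.

```python
import random
import zlib

random.seed(0)
NUM_WEIGHTS = 20_280  # same size as the dense kernel in this tutorial

# Unstructured int8 weights: nearly every byte value is equally likely
unstructured = [random.randint(-127, 127) for _ in range(NUM_WEIGHTS)]

# Sparse and clustered int8 weights: about half zeros, the rest from 7 codes
codes = [-100, -60, -25, 15, 40, 80, 127]  # hypothetical cluster codes
structured = [0 if random.random() < 0.5 else random.choice(codes)
              for _ in range(NUM_WEIGHTS)]


def deflated_size(values):
    """Size in bytes of int8 values after DEFLATE compression."""
    raw = bytes(v & 0xFF for v in values)  # two's-complement int8 bytes
    return len(zlib.compress(raw, 9))


print("unstructured:", deflated_size(unstructured), "bytes")
print("sparse + clustered:", deflated_size(structured), "bytes")
```

The unstructured bytes barely compress, while the sparse, clustered bytes compress much further. That is the same effect behind the smaller PCQAT file, although the exact ratio for a real model depends on its layout.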

See the persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with the
    # highest logit (the model outputs unnormalized logits).
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Evaluate the pruned, clustered and quantized model, and observe that the accuracy from TensorFlow persists in the TFLite backend.

interpreter = tf.lite.Interpreter(pcqat_model_file)
interpreter.allocate_tensors()

pcqat_test_accuracy = eval_model(interpreter)

print('Pruned, clustered and quantized TFLite test_accuracy:', pcqat_test_accuracy)
print('Baseline TF test accuracy:', baseline_model_accuracy)

/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/interpreter.py:457: UserWarning: Warning: tf.lite.Interpreter is deprecated and is scheduled for deletion in TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package. See the migration guide for details.
  warnings.warn(_INTERPRETER_DELETION_WARNING)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.

Pruned, clustered and quantized TFLite test_accuracy: 0.9817
Baseline TF test accuracy: 0.983299970626831

Conclusion

In this tutorial, you learned how to create a model, prune it using the prune_low_magnitude() API, and apply sparsity preserving clustering using the cluster_weights() API to preserve sparsity while clustering the weights.

Next, sparsity and cluster preserving quantization aware training (PCQAT) was applied to preserve model sparsity and clusters while using QAT. Comparing the final PCQAT model with the QAT one showed that sparsity and clusters are preserved by the former and lost by the latter.

The models were then converted to TFLite to show the compression benefits of chaining the sparsity, clustering, and PCQAT optimization techniques (7.541 KB for the zipped PCQAT model versus 13.81 KB for the QAT model), and the TFLite model was evaluated to ensure that the accuracy persists in the TFLite backend.

Finally, the PCQAT TFLite model accuracy was compared to the pre-optimization baseline model accuracy to show that the collaborative optimization techniques achieved these compression benefits while keeping accuracy close to that of the original model (0.9817 versus 0.9833 test accuracy in this run).