Sparsity preserving clustering Keras example
Overview
This is an end-to-end example showing the usage of the sparsity preserving clustering API, part of the TensorFlow Model Optimization Toolkit's collaborative optimization pipeline.
Other pages
For an introduction to the pipeline and other available techniques, see the collaborative optimization overview page.
Contents
In the tutorial, you will:
- Train a Keras model for the MNIST dataset from scratch.
- Fine-tune the model with pruning, check its accuracy, and observe that the model was successfully pruned.
- Apply weight clustering to the pruned model and observe the loss of sparsity.
- Apply sparsity preserving clustering on the pruned model and observe that the sparsity applied earlier has been preserved.
- Generate a TFLite model and check that the accuracy has been preserved in the pruned clustered model.
- Compare the sizes of the different models to observe the compression benefits of applying sparsity followed by the collaborative optimization technique of sparsity preserving clustering.
Setup
You can run this Jupyter Notebook in your local virtualenv or in Colab. For details on setting up dependencies, refer to the installation guide.
pip install -q tensorflow-model-optimization
import tensorflow as tf
import tf_keras as keras
import numpy as np
import tempfile
import zipfile
import os
Train a keras model for MNIST to be pruned and clustered
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
model = keras.Sequential([
keras.layers.InputLayer(input_shape=(28, 28)),
keras.layers.Reshape(target_shape=(28, 28, 1)),
keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
activation=tf.nn.relu),
keras.layers.MaxPooling2D(pool_size=(2, 2)),
keras.layers.Flatten(),
keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(
train_images,
train_labels,
validation_split=0.1,
epochs=10
)
Epoch 1/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.3041 - accuracy: 0.9141 - val_loss: 0.1204 - val_accuracy: 0.9682
Epoch 2/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.1212 - accuracy: 0.9652 - val_loss: 0.0824 - val_accuracy: 0.9792
Epoch 3/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0829 - accuracy: 0.9755 - val_loss: 0.0712 - val_accuracy: 0.9810
Epoch 4/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0670 - accuracy: 0.9804 - val_loss: 0.0611 - val_accuracy: 0.9852
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0576 - accuracy: 0.9831 - val_loss: 0.0614 - val_accuracy: 0.9845
Epoch 6/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0506 - accuracy: 0.9845 - val_loss: 0.0604 - val_accuracy: 0.9838
Epoch 7/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0452 - accuracy: 0.9863 - val_loss: 0.0612 - val_accuracy: 0.9852
Epoch 8/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0412 - accuracy: 0.9875 - val_loss: 0.0582 - val_accuracy: 0.9852
Epoch 9/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0377 - accuracy: 0.9883 - val_loss: 0.0571 - val_accuracy: 0.9850
Epoch 10/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0343 - accuracy: 0.9895 - val_loss: 0.0580 - val_accuracy: 0.9855
Evaluate the baseline model and save it for later usage
_, baseline_model_accuracy = model.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.983299970626831
Saving model to: /tmpfs/tmp/tmpmm2gdo8k.h5
Prune and fine-tune the model to 50% sparsity
Apply the prune_low_magnitude() API to prune the whole pre-trained model, producing the model that will be clustered in the next step. For guidance on using the API to achieve the best compression rate while maintaining your target accuracy, refer to the pruning comprehensive guide.
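Pruning to 50% sparsity means zeroing the half of each kernel's weights that have the smallest magnitudes. The NumPy sketch below illustrates that idea with a hypothetical helper, magnitude_prune, applied to a random kernel shaped like this model's conv2d kernel (3, 3, 1, 12). It is only an illustration: TFMOT computes the mask inside the PruneLowMagnitude wrapper and updates it on a schedule during training, and its tie-breaking may differ.

```python
import numpy as np

def magnitude_prune(weights: np.ndarray, target_sparsity: float) -> np.ndarray:
    """Hypothetical helper: zero the smallest-magnitude fraction of `weights`.

    Zeroes exactly int(target_sparsity * size) entries, breaking ties by position.
    """
    flat = weights.ravel()
    k = int(target_sparsity * flat.size)  # number of weights to zero out
    mask = np.ones(flat.size, dtype=bool)
    if k > 0:
        # Indices of the k weights with the smallest absolute value.
        smallest = np.argsort(np.abs(flat), kind="stable")[:k]
        mask[smallest] = False
    return (flat * mask).reshape(weights.shape)

rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 3, 1, 12))  # same shape as the conv2d kernel
pruned = magnitude_prune(kernel, 0.5)
print(np.count_nonzero(pruned == 0), "/", pruned.size)  # 54 / 108
```

The 54/108 count matches what the tutorial reports for the real conv2d kernel after pruning.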
Define the model and apply the sparsity API
Note that the pre-trained model is used.
import tensorflow_model_optimization as tfmot
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
}
callbacks = [
tfmot.sparsity.keras.UpdatePruningStep()
]
pruned_model = prune_low_magnitude(model, **pruning_params)
# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)
pruned_model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=opt,
metrics=['accuracy'])
pruned_model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                                          Output Shape         Param #
=================================================================
 prune_low_magnitude_reshape (PruneLowMagnitude)       (None, 28, 28, 1)    1
 prune_low_magnitude_conv2d (PruneLowMagnitude)        (None, 26, 26, 12)   230
 prune_low_magnitude_max_pooling2d (PruneLowMagnitude) (None, 13, 13, 12)   1
 prune_low_magnitude_flatten (PruneLowMagnitude)       (None, 2028)         1
 prune_low_magnitude_dense (PruneLowMagnitude)         (None, 10)           40572
=================================================================
Total params: 40805 (159.41 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 20395 (79.69 KB)
_________________________________________________________________
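The parameter counts in the summary can be reproduced by hand. The trainable parameters are the original kernels and biases; the non-trainable ones come from the pruning wrappers. The breakdown below assumes that each PruneLowMagnitude wrapper stores a mask with the same shape as its kernel, a scalar threshold, and a scalar pruning step, and that wrappers around weightless layers store only the step. This assumption is inferred from the numbers rather than taken from the TFMOT source, but it reproduces every figure in the summary.

```python
# Trainable parameters of the two layers that have weights.
conv_kernel = 3 * 3 * 1 * 12      # kernel height * width * in_channels * filters = 108
conv_bias = 12
dense_kernel = 13 * 13 * 12 * 10  # 2028 flattened inputs * 10 units = 20280
dense_bias = 10
trainable = conv_kernel + conv_bias + dense_kernel + dense_bias

# Assumed non-trainable state per PruneLowMagnitude wrapper:
# a mask shaped like the kernel, a scalar threshold and a scalar pruning step.
conv_extra = conv_kernel + 1 + 1
dense_extra = dense_kernel + 1 + 1
weightless_wrappers = 3 * 1       # reshape, max_pooling2d, flatten: pruning step only
non_trainable = conv_extra + dense_extra + weightless_wrappers

print(trainable, non_trainable, trainable + non_trainable)  # 20410 20395 40805
```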
Fine-tune the model, check sparsity, and evaluate the accuracy against baseline
Fine-tune the model with pruning for 3 epochs.
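ConstantSparsity(0.5, begin_step=0, frequency=100) keeps the target at 50% and, according to its parameters, applies pruning every 100 training steps starting at step 0. The arithmetic below estimates how many mask updates that gives during this 3-epoch fine-tuning. should_update_mask is a hypothetical helper written for this illustration, and TFMOT's internal step counter may be offset by one, so treat the count as approximate.

```python
import math

def should_update_mask(step: int, begin_step: int = 0, frequency: int = 100,
                       end_step: int = -1) -> bool:
    """Hypothetical helper: True on the steps where a begin_step/frequency
    schedule would recompute the pruning mask."""
    if step < begin_step:
        return False
    if end_step != -1 and step > end_step:
        return False
    return (step - begin_step) % frequency == 0

num_train = 60_000 * 9 // 10                 # validation_split=0.1 keeps 54,000 images
steps_per_epoch = math.ceil(num_train / 32)  # Keras default batch_size of 32
total_steps = 3 * steps_per_epoch            # epochs=3
mask_updates = sum(should_update_mask(step) for step in range(total_steps))
print(steps_per_epoch, total_steps, mask_updates)  # 1688 5064 51
```

The 1688 steps per epoch agree with the progress bars in the training logs.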
# Fine-tune model
pruned_model.fit(
train_images,
train_labels,
epochs=3,
validation_split=0.1,
callbacks=callbacks)
Epoch 1/3
1688/1688 [==============================] - 9s 4ms/step - loss: 0.0537 - accuracy: 0.9822 - val_loss: 0.0692 - val_accuracy: 0.9805
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0445 - accuracy: 0.9862 - val_loss: 0.0646 - val_accuracy: 0.9823
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0415 - accuracy: 0.9876 - val_loss: 0.0628 - val_accuracy: 0.9823
Define a helper function to calculate and print the sparsity of the model's kernels.
def print_model_weights_sparsity(model):
    for layer in model.layers:
        # For wrapped layers, only inspect the trainable weights.
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # Skip non-kernel weights (e.g. biases) and clustering centroids.
            if "kernel" not in weight.name or "centroid" in weight.name:
                continue
            weight_size = weight.numpy().size
            zero_num = np.count_nonzero(weight == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )
Check that the model kernels were correctly pruned. The pruning wrappers need to be stripped first. A deep copy of the model is also created for use in the next step.
stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
print_model_weights_sparsity(stripped_pruned_model)
stripped_pruned_model_copy = keras.models.clone_model(stripped_pruned_model)
stripped_pruned_model_copy.set_weights(stripped_pruned_model.get_weights())
conv2d/kernel:0: 50.00% sparsity (54/108)
dense/kernel:0: 50.00% sparsity (10140/20280)
Apply clustering and sparsity preserving clustering and check its effect on model sparsity in both cases
Next, apply both plain clustering and sparsity preserving clustering to the pruned model, and observe that only the latter preserves the sparsity of the pruned model. Note that the pruning wrappers were stripped from the pruned model with tfmot.sparsity.keras.strip_pruning before the clustering API was applied.
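In plain weight clustering, every weight is replaced by the centroid of its cluster, so a pruned (zero) weight remains zero only if its centroid is exactly zero. The NumPy sketch below illustrates this with a hypothetical helper, cluster_weights_sketch. It is a simplification, not TFMOT's implementation: it uses linearly spaced centroids instead of the k-means++ initialization used in this tutorial, and it skips the centroid fine-tuning that happens during training, which can also move a zero-valued centroid away from zero.

```python
import numpy as np

def cluster_weights_sketch(weights: np.ndarray, number_of_clusters: int) -> np.ndarray:
    """Hypothetical helper: replace every weight by the nearest of
    `number_of_clusters` centroids spaced linearly between min and max."""
    centroids = np.linspace(weights.min(), weights.max(), number_of_clusters)
    # Distance from every weight to every centroid, then pick the closest one.
    nearest = np.abs(weights[..., None] - centroids).argmin(axis=-1)
    return centroids[nearest]

# Toy "pruned" kernel: half of the entries are exactly zero.
pruned = np.array([0.0, 0.0, 0.0, 0.0, -1.0, -0.6, 0.7, 1.0])
clustered = cluster_weights_sketch(pruned, number_of_clusters=8)
print(np.count_nonzero(pruned == 0), np.count_nonzero(clustered == 0))  # 4 0
```

With 8 centroids spread evenly over [-1, 1], none falls on zero, so every pruned weight is snapped to a small nonzero value and the sparsity is lost.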
# Clustering
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
clustering_params = {
'number_of_clusters': 8,
'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}
clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)
clustered_model.compile(optimizer='adam',
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
print('Train clustering model:')
clustered_model.fit(train_images, train_labels, epochs=3, validation_split=0.1)
stripped_pruned_model.save("stripped_pruned_model_clustered.h5")
# Sparsity preserving clustering
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
cluster,
)
cluster_weights = cluster.cluster_weights
clustering_params = {
'number_of_clusters': 8,
'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
'preserve_sparsity': True
}
sparsity_clustered_model = cluster_weights(stripped_pruned_model_copy, **clustering_params)
sparsity_clustered_model.compile(optimizer='adam',
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
print('Train sparsity preserving clustering model:')
sparsity_clustered_model.fit(train_images, train_labels, epochs=3, validation_split=0.1)
Train clustering model:
Epoch 1/3
1688/1688 [==============================] - 8s 4ms/step - loss: 0.0398 - accuracy: 0.9876 - val_loss: 0.0634 - val_accuracy: 0.9843
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0392 - accuracy: 0.9876 - val_loss: 0.0674 - val_accuracy: 0.9808
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0406 - accuracy: 0.9864 - val_loss: 0.0633 - val_accuracy: 0.9845
Train sparsity preserving clustering model:
Epoch 1/3
1688/1688 [==============================] - 8s 4ms/step - loss: 0.0393 - accuracy: 0.9885 - val_loss: 0.0565 - val_accuracy: 0.9848
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0380 - accuracy: 0.9881 - val_loss: 0.0665 - val_accuracy: 0.9827
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0393 - accuracy: 0.9871 - val_loss: 0.0621 - val_accuracy: 0.9840
Check sparsity for both models.
print("Clustered Model sparsity:\n")
print_model_weights_sparsity(clustered_model)
print("\nSparsity preserved clustered Model sparsity:\n")
print_model_weights_sparsity(sparsity_clustered_model)
Clustered Model sparsity:
conv2d/kernel:0: 0.00% sparsity (0/108)
dense/kernel:0: 1.24% sparsity (251/20280)
Sparsity preserved clustered Model sparsity:
conv2d/kernel:0: 50.00% sparsity (54/108)
dense/kernel:0: 50.00% sparsity (10140/20280)
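One way to picture what preserve_sparsity changes: remember which weights are zero, cluster only the surviving weights, and re-apply the zero mask afterwards. The sketch below uses a hypothetical helper, sparsity_preserving_cluster_sketch, with linearly spaced centroids and no training; it is an assumption-based simplification rather than TFMOT's implementation, but it shows the qualitative result above, where the pruned positions stay exactly zero.

```python
import numpy as np

def sparsity_preserving_cluster_sketch(weights: np.ndarray,
                                       number_of_clusters: int) -> np.ndarray:
    """Hypothetical helper: cluster only the non-zero weights and keep the
    pruned (zero) weights at exactly zero."""
    mask = weights != 0  # remember which weights survived pruning
    if not mask.any():
        return weights.copy()  # nothing to cluster
    nonzero = weights[mask]
    # Linearly spaced centroids over the surviving weights (illustrative choice).
    centroids = np.linspace(nonzero.min(), nonzero.max(), number_of_clusters)
    nearest = np.abs(weights[..., None] - centroids).argmin(axis=-1)
    clustered = centroids[nearest]
    # Re-apply the sparsity mask so every pruned weight is zero again.
    return np.where(mask, clustered, 0.0)

pruned = np.array([0.0, 0.0, 0.0, 0.0, -1.0, -0.6, 0.7, 1.0])
clustered = sparsity_preserving_cluster_sketch(pruned, number_of_clusters=8)
print(np.count_nonzero(pruned == 0), np.count_nonzero(clustered == 0))  # 4 4
```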
Create 1.6x smaller models from clustering
Define a helper function to get the size of the zipped model file.
def get_gzipped_model_size(file):
    # Returns the size, in kilobytes, of the model file after compressing it
    # into a ZIP archive with DEFLATE.
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file)
    return os.path.getsize(zipped_file)/1000
# Clustered model
clustered_model_file = 'clustered_model.h5'
# Save the model.
clustered_model.save(clustered_model_file)
# Sparsity preserving clustered model
sparsity_clustered_model_file = 'sparsity_clustered_model.h5'
# Save the model.
sparsity_clustered_model.save(sparsity_clustered_model_file)
print("Clustered Model size: ", get_gzipped_model_size(clustered_model_file), ' KB')
print("Sparsity preserved clustered Model size: ", get_gzipped_model_size(sparsity_clustered_model_file), ' KB')
Clustered Model size: 245.092 KB
Sparsity preserved clustered Model size: 141.373 KB
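The size gap comes from how DEFLATE (the algorithm behind zipfile.ZIP_DEFLATED in get_gzipped_model_size) treats the data: arrays with few distinct values, and especially with many repeated zeros, compress far better than arbitrary floats. The self-contained sketch below uses synthetic float32 arrays (not the tutorial's weights) and Python's zlib, which implements the same DEFLATE algorithm. The saved .h5 files contain more than the kernels, so this illustrates the trend rather than reproducing the reported sizes.

```python
import zlib
import numpy as np

rng = np.random.default_rng(0)
n = 20280  # same number of values as the dense kernel

dense = rng.normal(size=n).astype(np.float32)         # arbitrary float weights
centroids = rng.normal(size=8).astype(np.float32)
clustered = centroids[rng.integers(0, 8, size=n)]     # only 8 distinct values
# Roughly half of the clustered values replaced by zeros, as after pruning.
sparse_clustered = np.where(rng.random(n) < 0.5, np.float32(0.0), clustered)

sizes = {}
for name, arr in [("dense", dense), ("clustered", clustered),
                  ("sparse + clustered", sparse_clustered)]:
    sizes[name] = len(zlib.compress(arr.tobytes(), 9))
    print(f"{name:>20}: {arr.nbytes} -> {sizes[name]} bytes")
```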
Create a TFLite model from combining sparsity preserving weight clustering and post-training quantization
Strip the clustering wrappers and convert the model to TFLite.
stripped_sparsity_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)
converter = tf.lite.TFLiteConverter.from_keras_model(stripped_sparsity_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
sparsity_clustered_quant_model = converter.convert()
_, pruned_and_clustered_tflite_file = tempfile.mkstemp('.tflite')
with open(pruned_and_clustered_tflite_file, 'wb') as f:
f.write(sparsity_clustered_quant_model)
print("Sparsity preserved clustered Model size: ", get_gzipped_model_size(sparsity_clustered_model_file), ' KB')
print("Sparsity preserved clustered and quantized TFLite model size:",
get_gzipped_model_size(pruned_and_clustered_tflite_file), ' KB')
Sparsity preserved clustered Model size: 141.373 KB
Sparsity preserved clustered and quantized TFLite model size: 7.847 KB
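Setting converter.optimizations = [tf.lite.Optimize.DEFAULT] without a representative dataset applies dynamic range quantization, which stores the weights as 8-bit integers. The sketch below assumes a simplified per-tensor symmetric int8 scheme through a hypothetical helper, quantize_int8_symmetric; the converter's actual scheme (for example, per-channel scales) may differ. It shows why quantization composes well with the earlier steps: zero maps to the integer code 0, and weights that share a cluster centroid share a single int8 code.

```python
import numpy as np

def quantize_int8_symmetric(weights: np.ndarray):
    """Hypothetical helper: per-tensor symmetric int8 quantization."""
    max_abs = float(np.abs(weights).max())
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale

# Toy pruned and clustered kernel: zeros plus a few shared centroid values.
kernel = np.array([0.0, 0.0, -0.82, 0.31, 0.31, 0.0, 1.27, -0.82], dtype=np.float32)
q, scale = quantize_int8_symmetric(kernel)
dequantized = q.astype(np.float32) * scale
print(q)  # int8 codes: zeros stay 0 and equal weights share one code
print(np.abs(dequantized - kernel).max() <= scale / 2)  # True
```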
See the persistence of accuracy from TF to TFLite
def eval_model(interpreter):
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    # Run predictions on every image in the "test" dataset.
    prediction_digits = []
    for i, test_image in enumerate(test_images):
        if i % 1000 == 0:
            print(f"Evaluated on {i} results so far.")
        # Pre-processing: add batch dimension and convert to float32 to match with
        # the model's input data format.
        test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, test_image)
        # Run inference.
        interpreter.invoke()
        # Post-processing: remove batch dimension and find the digit with the
        # highest logit (equivalently, the highest probability).
        output = interpreter.tensor(output_index)
        digit = np.argmax(output()[0])
        prediction_digits.append(digit)
    print('\n')
    # Compare prediction results with ground truth labels to calculate accuracy.
    prediction_digits = np.array(prediction_digits)
    accuracy = (prediction_digits == test_labels).mean()
    return accuracy
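The post-processing in eval_model is an argmax followed by a comparison with the labels. Because the model ends in Dense(10) trained with from_logits=True and softmax is monotonic, the argmax of the raw logits gives the same digit as the argmax of the probabilities. If the per-image outputs were collected into one array (a hypothetical refactor; eval_model as written handles one image at a time), the accuracy could be computed in a vectorized way, as in this sketch with a hypothetical helper, accuracy_from_logits:

```python
import numpy as np

def accuracy_from_logits(logits: np.ndarray, labels: np.ndarray) -> float:
    """Fraction of rows whose highest-scoring class matches the label."""
    predicted_digits = np.argmax(logits, axis=-1)  # same post-processing as eval_model
    return float((predicted_digits == labels).mean())

# Toy batch: 4 examples, 3 classes; the third prediction is wrong.
logits = np.array([[2.0, 0.1, -1.0],
                   [0.0, 3.0, 0.5],
                   [1.0, 0.2, 0.1],
                   [-0.5, 0.0, 4.0]])
labels = np.array([0, 1, 2, 2])
print(accuracy_from_logits(logits, labels))  # 0.75
```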
Evaluate the model, which has been pruned, clustered, and quantized, and confirm that the accuracy from TensorFlow persists in the TFLite backend.
# Keras model evaluation
stripped_sparsity_clustered_model.compile(optimizer='adam',
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
_, sparsity_clustered_keras_accuracy = stripped_sparsity_clustered_model.evaluate(
test_images, test_labels, verbose=0)
# TFLite model evaluation
interpreter = tf.lite.Interpreter(pruned_and_clustered_tflite_file)
interpreter.allocate_tensors()
sparsity_clustered_tflite_accuracy = eval_model(interpreter)
print('Pruned, clustered and quantized Keras model accuracy:', sparsity_clustered_keras_accuracy)
print('Pruned, clustered and quantized TFLite model accuracy:', sparsity_clustered_tflite_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.
Pruned, clustered and quantized Keras model accuracy: 0.9797999858856201
Pruned, clustered and quantized TFLite model accuracy: 0.98
Conclusion
In this tutorial, you learned how to create a model, prune it using the prune_low_magnitude() API, and apply sparsity preserving clustering to cluster the weights while keeping the sparsity introduced by pruning. The sparsity preserving clustered model was compared to a plainly clustered one to show that sparsity is preserved in the former and lost in the latter. Next, the pruned and clustered model was converted to TFLite to show the compression benefits of chaining the pruning and sparsity preserving clustering techniques. Finally, the TFLite model was evaluated to confirm that the accuracy persists in the TFLite backend.