Weight clustering in Keras example
Overview
Welcome to the end-to-end example for weight clustering, part of the TensorFlow Model Optimization Toolkit.
Other pages
For an introduction to what weight clustering is and to determine if you should use it (including what's supported), see the overview page.
To quickly find the APIs you need for your use case (beyond fully clustering a model with 16 clusters), see the comprehensive guide.
Contents
In the tutorial, you will:
- Train a Keras model for the MNIST dataset from scratch.
- Fine-tune the model by applying the weight clustering API and evaluate its accuracy.
- Create 6x smaller TF and TFLite models from clustering.
- Create an 8x smaller TFLite model by combining weight clustering and post-training quantization.
- See that the accuracy persists from TF to TFLite.
Setup
You can run this Jupyter Notebook in your local virtualenv or in Colab. For details on setting up dependencies, refer to the installation guide.
pip install -q tensorflow-model-optimization
import tensorflow as tf
from tensorflow_model_optimization.python.core.keras.compat import keras
import numpy as np
import tempfile
import zipfile
import os
Train a Keras model for MNIST without clustering
# Load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
# Define the model architecture.
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
Epoch 1/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.3055 - accuracy: 0.9125 - val_loss: 0.1378 - val_accuracy: 0.9657
Epoch 2/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.1328 - accuracy: 0.9621 - val_loss: 0.0960 - val_accuracy: 0.9732
Epoch 3/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0952 - accuracy: 0.9729 - val_loss: 0.0854 - val_accuracy: 0.9758
Epoch 4/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0761 - accuracy: 0.9774 - val_loss: 0.0781 - val_accuracy: 0.9772
Epoch 5/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0647 - accuracy: 0.9810 - val_loss: 0.0633 - val_accuracy: 0.9827
Epoch 6/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0555 - accuracy: 0.9837 - val_loss: 0.0632 - val_accuracy: 0.9840
Epoch 7/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0487 - accuracy: 0.9854 - val_loss: 0.0647 - val_accuracy: 0.9835
Epoch 8/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0441 - accuracy: 0.9862 - val_loss: 0.0598 - val_accuracy: 0.9848
Epoch 9/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0398 - accuracy: 0.9879 - val_loss: 0.0617 - val_accuracy: 0.9832
Epoch 10/10
1688/1688 [==============================] - 6s 4ms/step - loss: 0.0357 - accuracy: 0.9890 - val_loss: 0.0586 - val_accuracy: 0.9848
<tf_keras.src.callbacks.History at 0x7f0232f2cdc0>
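The layer shapes, the parameter counts and the 1688 steps per epoch reported above can be checked by hand. This is a minimal sketch that assumes the Keras defaults the code relies on implicitly: 'valid' padding and stride 1 for Conv2D, a pooling stride equal to pool_size for MaxPooling2D, and a batch size of 32 for model.fit. With validation_split=0.1, training uses 54,000 of the 60,000 MNIST training images.

```python
import math

# Conv2D(filters=12, kernel_size=(3, 3)), assumed 'valid' padding, stride 1: 28 -> 26
conv_out = 28 - 3 + 1
# MaxPooling2D(pool_size=(2, 2)), assumed stride 2: 26 -> 13
pool_out = conv_out // 2
# Flatten: 13 x 13 spatial positions x 12 filters
flatten_units = pool_out * pool_out * 12

conv_params = 3 * 3 * 1 * 12 + 12        # kernel weights + one bias per filter
dense_params = flatten_units * 10 + 10   # kernel weights + one bias per class

# validation_split=0.1 keeps 90% of the 60,000 images; assumed default batch size 32
train_samples = 60000 * 9 // 10
steps_per_epoch = math.ceil(train_samples / 32)

print(conv_out, pool_out, flatten_units)  # 26 13 2028
print(conv_params, dense_params)          # 120 20290
print(steps_per_epoch)                    # 1688
```

The last batch of each epoch is partial (54,000 / 32 = 1687.5), which is why the count rounds up.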
Evaluate the baseline model and save it for later usage
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9812999963760376
Saving model to: /tmpfs/tmp/tmp5523en8x.h5
Fine-tune the pre-trained model with clustering
Apply the cluster_weights() API to a whole pre-trained model to demonstrate how effectively it reduces the model size after zip compression while maintaining good accuracy. For how best to balance accuracy and compression rate for your use case, refer to the per-layer example in the comprehensive guide.
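Conceptually, weight clustering replaces every weight in a layer with the nearest of a small set of shared values (centroids), so the layer only needs to store the centroids plus a small index per weight. The plain-Python sketch below illustrates the idea under simplifying assumptions: it spaces 16 centroids evenly between the smallest and largest weight (our reading of the LINEAR initialization used below) and it skips the centroid fine-tuning that the real API performs during training. The helpers linear_centroids and cluster are hypothetical and not part of tfmot.

```python
import random

def linear_centroids(weights, k):
    """Hypothetical helper: k centroids evenly spaced from min(weights) to max(weights)."""
    lo, hi = min(weights), max(weights)
    step = (hi - lo) / (k - 1)
    return [lo + i * step for i in range(k)]

def cluster(weights, centroids):
    """Hypothetical helper: snap each weight to its nearest centroid.

    Returns the clustered weights and, for each weight, the index of its centroid.
    """
    indices = [min(range(len(centroids)), key=lambda j: abs(w - centroids[j]))
               for w in weights]
    return [centroids[j] for j in indices], indices

# Synthetic "layer" of 1,000 weights (not taken from the tutorial's model).
random.seed(0)
weights = [random.gauss(0.0, 0.1) for _ in range(1000)]
centroids = linear_centroids(weights, k=16)
clustered, indices = cluster(weights, centroids)

print(len(set(weights)))    # many distinct values before clustering
print(len(set(clustered)))  # at most 16 distinct values afterwards
```

With 16 clusters, each index fits in 4 bits, which is where the storage savings come from once the file is compressed.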
Define the model and apply the clustering API
Before you pass the model to the clustering API, make sure it is trained and shows some acceptable accuracy.
import tensorflow_model_optimization as tfmot
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
clustering_params = {
    'number_of_clusters': 16,
    'cluster_centroids_init': CentroidInitialization.LINEAR
}
# Cluster a whole model
clustered_model = cluster_weights(model, **clustering_params)
# Use a smaller learning rate for fine-tuning the clustered model
opt = keras.optimizers.Adam(learning_rate=1e-5)
clustered_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=opt,
    metrics=['accuracy'])
clustered_model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                            Output Shape          Param #
 cluster_reshape (ClusterWeights)        (None, 28, 28, 1)     0
 cluster_conv2d (ClusterWeights)         (None, 26, 26, 12)    244
 cluster_max_pooling2d (ClusterWeights)  (None, 13, 13, 12)    0
 cluster_flatten (ClusterWeights)        (None, 2028)          0
 cluster_dense (ClusterWeights)          (None, 10)            40586
Total params: 40830 (239.13 KB)
Trainable params: 20442 (79.85 KB)
Non-trainable params: 20388 (159.28 KB)
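The parameter counts in this summary can be reconciled with the baseline model. The breakdown below is an assumption inferred from the numbers rather than from the toolkit's source: each clustered layer keeps its original kernel and bias as trainable weights, adds 16 trainable centroids, and adds one non-trainable cluster index per kernel weight (biases are not clustered). clustered_layer_params is a hypothetical helper.

```python
CLUSTERS = 16

def clustered_layer_params(kernel_size, bias_size, clusters=CLUSTERS):
    """Hypothetical breakdown: (trainable, non_trainable) params of one clustered layer."""
    trainable = kernel_size + bias_size + clusters  # original weights + centroids
    non_trainable = kernel_size                     # one cluster index per kernel weight
    return trainable, non_trainable

conv = clustered_layer_params(kernel_size=3 * 3 * 1 * 12, bias_size=12)
dense = clustered_layer_params(kernel_size=2028 * 10, bias_size=10)

print(sum(conv), sum(dense))                   # 244 40586
print(conv[0] + dense[0], conv[1] + dense[1])  # 20442 20388
```

These extra training-time variables are why the clustered model looks larger than the baseline before stripping.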
Fine-tune the model and evaluate the accuracy against baseline
Fine-tune the model with clustering for 1 epoch.
# Fine-tune model
clustered_model.fit(
    train_images,
    train_labels,
    batch_size=500,
    epochs=1,
    validation_split=0.1)
108/108 [==============================] - 2s 13ms/step - loss: 0.0344 - accuracy: 0.9891 - val_loss: 0.0651 - val_accuracy: 0.9838 <tf_keras.src.callbacks.History at 0x7f01902820a0>
For this example, there is minimal loss in test accuracy after clustering, compared to the baseline.
_, clustered_model_accuracy = clustered_model.evaluate(
    test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)
Baseline test accuracy: 0.9812999963760376
Clustered test accuracy: 0.9782999753952026
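To put the accuracy loss in concrete terms: the MNIST test set has 10,000 images, so the gap between the two accuracies printed above corresponds to about 30 additional misclassified images. A quick calculation using those printed values:

```python
baseline_accuracy = 0.9812999963760376   # printed above
clustered_accuracy = 0.9782999753952026  # printed above
test_set_size = 10000                    # number of MNIST test images

extra_errors = round((baseline_accuracy - clustered_accuracy) * test_set_size)
print(extra_errors)  # 30
```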
Create 6x smaller models from clustering
Both strip_clustering and applying a standard compression algorithm (e.g. via gzip) are necessary to see the compression benefits of clustering.
First, create a compressible model for TensorFlow. Here, strip_clustering removes all variables (e.g. the tf.Variable objects that store the cluster centroids and indices) that clustering needs only during training and that would otherwise add to the model size during inference.
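Why compression is needed can be seen with a small standard-library experiment. Clustered weights are still stored as 32-bit floats, so the uncompressed size does not change; but with at most 16 distinct values, the byte stream becomes highly repetitive and DEFLATE (the algorithm behind zlib, gzip and zip) can shrink it substantially. This sketch uses synthetic Gaussian weights and a simplified evenly spaced codebook, not the tutorial's model, so the exact ratios will differ.

```python
import random
import struct
import zlib

random.seed(0)
weights = [random.gauss(0.0, 0.1) for _ in range(20000)]

# Snap each weight to the nearest of 16 evenly spaced centroids (simplified clustering).
lo, hi = min(weights), max(weights)
step = (hi - lo) / 15
centroids = [lo + i * step for i in range(16)]
clustered = [centroids[round((w - lo) / step)] for w in weights]

def as_float32_bytes(values):
    """Pack values as little-endian float32, like a stored weight tensor."""
    return struct.pack(f"<{len(values)}f", *values)

raw = as_float32_bytes(weights)
raw_clustered = as_float32_bytes(clustered)

print(len(raw), len(raw_clustered))          # same uncompressed size: 80000 80000
print(len(zlib.compress(raw, 9)))            # unclustered floats barely compress
print(len(zlib.compress(raw_clustered, 9)))  # clustered floats compress much better
```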
final_model = tfmot.clustering.keras.strip_clustering(clustered_model)
_, clustered_keras_file = tempfile.mkstemp('.h5')
print('Saving clustered model to: ', clustered_keras_file)
keras.models.save_model(final_model, clustered_keras_file,
                        include_optimizer=False)
Saving clustered model to: /tmpfs/tmp/tmpxjfulsu5.h5
Then, create compressible models for TFLite. You can convert the clustered model to a format that runs on your target backend. TensorFlow Lite is one example you can use to deploy to mobile devices.
clustered_tflite_file = '/tmp/clustered_mnist.tflite'
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
tflite_clustered_model = converter.convert()
with open(clustered_tflite_file, 'wb') as f:
    f.write(tflite_clustered_model)
print('Saved clustered TFLite model to:', clustered_tflite_file)
Saved clustered TFLite model to: /tmp/clustered_mnist.tflite
Define a helper function that compresses a model file with DEFLATE (the algorithm gzip uses, here applied through zipfile) and returns the compressed size.
def get_gzipped_model_size(file):
    """Return the size in bytes of `file` after zip (DEFLATE) compression."""
    import os
    import zipfile
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file)
    return os.path.getsize(zipped_file)
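The helper above writes a zip archive; zip and gzip both use DEFLATE, so the results are comparable, differing only by container overhead. If you prefer to measure a true gzip stream, a minimal in-memory alternative using the standard-library gzip module might look like this (get_gzip_size is a hypothetical name, not used elsewhere in this tutorial):

```python
import gzip
import os
import tempfile

def get_gzip_size(file):
    """Return the size in bytes of `file` compressed as a gzip stream, in memory."""
    with open(file, 'rb') as f:
        return len(gzip.compress(f.read()))

# Usage example with a throwaway file of highly repetitive bytes.
fd, path = tempfile.mkstemp('.bin')
with os.fdopen(fd, 'wb') as f:
    f.write(b'\x00\x01\x02\x03' * 10000)
original_size = os.path.getsize(path)
compressed_size = get_gzip_size(path)
os.remove(path)
print(original_size, compressed_size)  # 40000 bytes before, far fewer after
```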
Compare and see that the models are 6x smaller from clustering
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped clustered Keras model: %.2f bytes" % (get_gzipped_model_size(clustered_keras_file)))
print("Size of gzipped clustered TFlite model: %.2f bytes" % (get_gzipped_model_size(clustered_tflite_file)))
Size of gzipped baseline Keras model: 78130.00 bytes
Size of gzipped clustered Keras model: 13330.00 bytes
Size of gzipped clustered TFlite model: 12848.00 bytes
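The "6x" figure is a rounded ratio of the compressed sizes printed above. Computing it directly from this run's numbers:

```python
baseline = 78130.00          # gzipped baseline Keras model, bytes (printed above)
clustered_keras = 13330.00   # gzipped clustered Keras model, bytes
clustered_tflite = 12848.00  # gzipped clustered TFLite model, bytes

print(f"Keras:  {baseline / clustered_keras:.2f}x smaller")   # Keras:  5.86x smaller
print(f"TFLite: {baseline / clustered_tflite:.2f}x smaller")  # TFLite: 6.08x smaller
```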
Create an 8x smaller TFLite model from combining weight clustering and post-training quantization
You can apply post-training quantization to the clustered model for further size reduction.
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
_, quantized_and_clustered_tflite_file = tempfile.mkstemp('.tflite')
with open(quantized_and_clustered_tflite_file, 'wb') as f:
    f.write(tflite_quant_model)
print('Saved quantized and clustered TFLite model to:', quantized_and_clustered_tflite_file)
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped clustered and quantized TFlite model: %.2f bytes" % (get_gzipped_model_size(quantized_and_clustered_tflite_file)))
Saved quantized and clustered TFLite model to: /tmpfs/tmp/tmp3e8eo4se.tflite
Size of gzipped baseline Keras model: 78130.00 bytes
Size of gzipped clustered and quantized TFlite model: 10848.00 bytes
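With tf.lite.Optimize.DEFAULT and no representative dataset, the converter applies dynamic-range quantization, storing weights as 8-bit integers instead of 32-bit floats. The sketch below uses a simplified symmetric per-tensor int8 scheme (TFLite's actual scheme, for example per-channel scales, may differ) to illustrate why this composes well with clustering: a layer whose weights take at most 16 distinct values still takes at most 16 distinct values after quantization, so it stays highly compressible while each stored value shrinks from 4 bytes to 1. quantize_int8 is a hypothetical helper.

```python
def quantize_int8(values):
    """Simplified symmetric per-tensor int8 quantization: q = round(v / scale)."""
    scale = max(abs(v) for v in values) / 127
    return [max(-127, min(127, round(v / scale))) for v in values], scale

# Pretend these are the 16 centroid values shared by a clustered layer.
centroids = [-0.3 + 0.04 * i for i in range(16)]
weights = [centroids[i % 16] for i in range(1000)]

q, scale = quantize_int8(weights)
dequantized = [v * scale for v in q]

print(len(set(weights)), len(set(q)))  # 16 16
# Rounding error per weight is at most half a quantization step.
print(max(abs(w - d) for w, d in zip(weights, dequantized)) <= scale / 2 + 1e-12)  # True
```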
See the persistence of accuracy from TF to TFLite
Define a helper function to evaluate the TFLite model on the test dataset.
def eval_model(interpreter):
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    # Run predictions on every image in the "test" dataset.
    prediction_digits = []
    for i, test_image in enumerate(test_images):
        if i % 1000 == 0:
            print('Evaluated on {n} results so far.'.format(n=i))
        # Pre-processing: add batch dimension and convert to float32 to match
        # the model's input data format.
        test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, test_image)
        # Run inference.
        interpreter.invoke()
        # Post-processing: remove batch dimension and find the digit with the
        # highest score (the model outputs logits).
        output = interpreter.tensor(output_index)
        digit = np.argmax(output()[0])
        prediction_digits.append(digit)
    print('\n')
    # Compare prediction results with ground truth labels to calculate accuracy.
    prediction_digits = np.array(prediction_digits)
    accuracy = (prediction_digits == test_labels).mean()
    return accuracy
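The model's final Dense layer outputs logits (the loss is built with from_logits=True), so eval_model takes the argmax over raw scores rather than probabilities. This is safe because each softmax probability is proportional to exp(logit), so softmax preserves the ordering of the logits and never changes which class scores highest. A quick check in plain Python:

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(xs):
    """Index of the largest element."""
    return max(range(len(xs)), key=xs.__getitem__)

random.seed(0)
for _ in range(1000):
    logits = [random.uniform(-10, 10) for _ in range(10)]  # 10 digit classes
    assert argmax(logits) == argmax(softmax(logits))
print("argmax(logits) == argmax(softmax(logits)) in all trials")
```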
Evaluate the clustered and quantized model, and observe that the accuracy from TensorFlow persists in the TFLite backend.
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
test_accuracy = eval_model(interpreter)
print('Clustered and quantized TFLite test_accuracy:', test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.
Clustered and quantized TFLite test_accuracy: 0.9774
Clustered TF test accuracy: 0.9782999753952026
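The gap between these two accuracies is about 0.0009, which on the 10,000-image MNIST test set amounts to roughly 9 images. Using the printed values:

```python
tflite_accuracy = 0.9774            # printed above
tf_accuracy = 0.9782999753952026    # printed above
test_set_size = 10000               # number of MNIST test images

diff_images = round((tf_accuracy - tflite_accuracy) * test_set_size)
print(diff_images)  # 9
```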
Conclusion
In this tutorial, you saw how to create clustered models with the TensorFlow Model Optimization Toolkit API. More specifically, you've been through an end-to-end example for creating an 8x smaller model for MNIST with minimal accuracy difference. We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.