Quantization aware training in Keras example

Overview

Welcome to an end-to-end example for quantization aware training.

Other pages

For an introduction to what quantization aware training is and to determine if you should use it (including what's supported), see the overview page.

To quickly find the APIs you need for your use case (beyond fully quantizing a model with 8-bits), see the comprehensive guide.

Summary

In this tutorial, you will:

  1. Train a Keras model for MNIST from scratch.
  2. Fine-tune the model by applying the quantization aware training API, see the accuracy, and export a quantization aware model.
  3. Use the model to create an actually quantized model for the TFLite backend.
  4. See the persistence of accuracy in TFLite and a 4x smaller model. To see the latency benefits on mobile, try out the TFLite examples in the TFLite app repository.

Setup

pip install -q tensorflow
pip install -q tensorflow-model-optimization

import tempfile
import os

import tensorflow as tf

from tensorflow_model_optimization.python.core.keras.compat import keras


Train a model for MNIST without quantization aware training

# Load the MNIST dataset.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)

1688/1688 [==============================] - 7s 4ms/step - loss: 0.2928 - accuracy: 0.9161 - val_loss: 0.1193 - val_accuracy: 0.9695
<tf_keras.src.callbacks.History at 0x7fb150eedd90>

Clone and fine-tune pre-trained model with quantization aware training

Define the model

You will apply quantization aware training to the whole model and see this in the model summary. All layers are now prefixed by "quant".

Note that the resulting model is quantization aware but not quantized (e.g. the weights are float32 instead of int8). The following sections show how to create a quantized model from the quantization aware one.

In the comprehensive guide, you can see how to quantize some layers for model accuracy improvements.
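
As a preview of that approach, here is a minimal sketch (not part of this tutorial's end-to-end flow) that annotates only the final Dense layer with quantize_annotate_layer and then calls quantize_apply, leaving the other layers in float:

import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.keras.compat import keras

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer

# Same architecture as the baseline model, but only the Dense layer is
# annotated for quantization; the other layers stay in float.
annotated_model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  quantize_annotate_layer(keras.layers.Dense(10)),
])

# `quantize_apply` makes only the annotated layers quantization aware.
selective_q_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)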

import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

q_aware_model.summary()

Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param #

quantize_layer (QuantizeLa (None, 28, 28) 3
yer)

quant_reshape (QuantizeWra (None, 28, 28, 1) 1
pperV2)

quant_conv2d (QuantizeWrap (None, 26, 26, 12) 147
perV2)

quant_max_pooling2d (Quant (None, 13, 13, 12) 1
izeWrapperV2)

quant_flatten (QuantizeWra (None, 2028) 1
pperV2)

quant_dense (QuantizeWrapp (None, 10) 20295
erV2)

Total params: 20448 (79.88 KB) Trainable params: 20410 (79.73 KB) Non-trainable params: 38 (152.00 Byte)
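
As a quick check of the earlier note (not part of the original output), you can confirm that the quantization aware model still stores float weights; int8 behavior is only emulated during training:

# The wrapped layers still hold float32 weights; quantization is emulated
# with fake-quantization during training, not applied to the stored values.
for weight in q_aware_model.weights[:5]:
  print(weight.name, weight.dtype)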


Train and evaluate the model against baseline

To demonstrate fine-tuning after training the model for just one epoch, fine-tune with quantization aware training on a subset of the training data.

train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]

q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)

2/2 [==============================] - 1s 204ms/step - loss: 0.1449 - accuracy: 0.9600 - val_loss: 0.1956 - val_accuracy: 0.9500
<tf_keras.src.callbacks.History at 0x7fb05b2bd640>

For this example, there is minimal to no loss in test accuracy after quantization aware training, compared to the baseline.

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)

Baseline test accuracy: 0.9623000025749207
Quant test accuracy: 0.9631999731063843

Create quantized model for TFLite backend

After running the TFLite converter with the default optimizations (below), you have an actually quantized model with int8 weights and uint8 activations.

converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()

INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp9qotso6d/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:854: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
W0000 00:00:1746100643.396464 44767 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1746100643.396500 44767 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
I0000 00:00:1746100643.416650 44767 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled
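
If you want to confirm the weight types yourself, one option (a sketch, not part of the original tutorial; tensor names depend on the converted graph) is to inspect the tensor details reported by the interpreter:

# List the dtypes of the convolution and dense weight tensors in the
# quantized flatbuffer; they should be integer types such as int8.
inspection_interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
inspection_interpreter.allocate_tensors()

for detail in inspection_interpreter.get_tensor_details():
  name = detail['name'].lower()
  if 'conv2d' in name or 'dense' in name:
    print(detail['name'], detail['dtype'])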

See persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

You evaluate the quantized model and see that the accuracy from TensorFlow persists to the TFLite backend.

interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)

/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/interpreter.py:457: UserWarning: Warning: tf.lite.Interpreter is deprecated and is scheduled for deletion in TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package. See the migration guide for details.
  warnings.warn(_INTERPRETER_DELETION_WARNING)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.

Quant TFLite test_accuracy: 0.9632
Quant TF test accuracy: 0.9631999731063843

See 4x smaller model from quantization

You create a float TFLite model and then see that the quantized TFLite model is roughly 4x smaller.

# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()

# Measure sizes of models.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
  f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
  f.write(float_tflite_model)

print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))

INFO:tensorflow:Assets written to: /tmpfs/tmp/tmplh2urc2w/assets
Float model in Mb: 0.08069610595703125
Quantized model in Mb: 0.02361297607421875
W0000 00:00:1746100644.845711 44767 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1746100644.845737 44767 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
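
To print the size reduction directly (a small addition, not in the original output):

# Ratio of the float TFLite file size to the quantized TFLite file size.
float_size = os.path.getsize(float_file)
quant_size = os.path.getsize(quant_file)
print("Size reduction: {:.1f}x".format(float_size / quant_size))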

Conclusion

In this tutorial, you saw how to create quantization aware models with the TensorFlow Model Optimization Toolkit API and then create quantized models for the TFLite backend.

You saw a 4x model size compression benefit for a model for MNIST, with minimal accuracy difference. To see the latency benefits on mobile, try out the TFLite examples in the TFLite app repository.

We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.