Image classification with Model Garden
This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow Model Garden package (tensorflow-models) to classify images in the CIFAR-10 dataset.
Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.
This tutorial uses a ResNet model, a state-of-the-art image classifier; specifically, it uses ResNet-18, a convolutional neural network with 18 layers.
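The "18" counts weighted layers. As a quick check, assuming the standard layout from the original ResNet paper (four stages of "basic" residual blocks with two 3x3 convolutions each, [2, 2, 2, 2] blocks for ResNet-18 and [3, 4, 6, 3] for ResNet-34) and the usual convention of counting the stem convolution and the final fully connected layer but not pooling or 1x1 projection shortcuts, the depth can be reproduced with a few lines of Python. `count_weighted_layers` is a hypothetical helper, not part of Model Garden.

```python
# Residual blocks per stage for the two "basic block" ResNets
# (assumed standard layouts; each basic block holds two 3x3 convolutions).
BLOCKS_PER_STAGE = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3]}


def count_weighted_layers(blocks_per_stage, convs_per_block=2):
  """Counts the stem conv, the convs inside residual blocks, and the classifier.

  By convention, pooling layers and 1x1 projection shortcuts are not counted.
  """
  stem_conv = 1
  block_convs = convs_per_block * sum(blocks_per_stage)
  classifier = 1
  return stem_conv + block_convs + classifier


for depth, blocks in BLOCKS_PER_STAGE.items():
  print(f'ResNet-{depth}: {count_weighted_layers(blocks)} weighted layers')
# ResNet-18: 18 weighted layers
# ResNet-34: 34 weighted layers
```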
This tutorial demonstrates how to:
- Use models from the TensorFlow Models package.
- Fine-tune a pre-built ResNet for image classification.
- Export the tuned ResNet model.
Setup
Install and import the necessary modules.
pip install -U -q "tf-models-official"
Import TensorFlow, TensorFlow Datasets, and a few helper libraries.
import pprint
import tempfile
from IPython import display
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
The tensorflow_models package contains the ResNet vision model, and the official.vision.serving module contains the function to save and export the tuned model.
import tensorflow_models as tfm
# These are not in the tfm public API for v2.9. They will be available in v2.10
from official.vision.serving import export_saved_model_lib
import official.core.train_lib
Configure the ResNet-18 model for the CIFAR-10 dataset
The CIFAR-10 dataset contains 60,000 color images in 10 mutually exclusive classes, with 6,000 images in each class.
In Model Garden, the collections of parameters that define a model are called configs. Model Garden can create a config based on a known set of parameters via a factory.
Use the resnet_imagenet factory configuration, as defined by tfm.vision.configs.image_classification.image_classification_imagenet. The configuration is set up to train ResNet to converge on ImageNet.
exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')
tfds_name = 'cifar10'
ds, ds_info = tfds.load(
tfds_name,
with_info=True)
ds_info
tfds.core.DatasetInfo(
    name='cifar10',
    full_name='cifar10/3.0.2',
    description="""
    The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
    """,
    homepage='https://www.cs.toronto.edu/~kriz/cifar.html',
    data_dir='gs://tensorflow-datasets/datasets/cifar10/3.0.2',
    file_format=tfrecord,
    download_size=162.17 MiB,
    dataset_size=132.40 MiB,
    features=FeaturesDict({
        'id': Text(shape=(), dtype=string),
        'image': Image(shape=(32, 32, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': ,
        'train': ,
    },
    citation="""@TECHREPORT{Krizhevsky09learningmultiple,
        author = {Alex Krizhevsky},
        title = {Learning multiple layers of features from tiny images},
        institution = {},
        year = {2009}
    }""",
)
Adjust the model and dataset configurations so that they work with CIFAR-10 (cifar10).
# Configure model
exp_config.task.model.num_classes = 10
exp_config.task.model.input_size = list(ds_info.features["image"].shape)
exp_config.task.model.backbone.resnet.model_id = 18
# Configure training and testing data
batch_size = 128
exp_config.task.train_data.input_path = ''
exp_config.task.train_data.tfds_name = tfds_name
exp_config.task.train_data.tfds_split = 'train'
exp_config.task.train_data.global_batch_size = batch_size
exp_config.task.validation_data.input_path = ''
exp_config.task.validation_data.tfds_name = tfds_name
exp_config.task.validation_data.tfds_split = 'test'
exp_config.task.validation_data.global_batch_size = batch_size
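The pattern used here, fetching a named default config from a factory and then overwriting only the fields that differ, can be sketched in plain Python. Everything in this block is hypothetical: `register_config_factory`, `get_config_sketch`, the dataclasses, and their placeholder defaults are illustrations, not Model Garden's config classes or the `tfm.core.exp_factory` registration mechanism, which carry many more fields.

```python
import dataclasses
from typing import Callable, Dict, List


@dataclasses.dataclass
class ModelConfig:
  # Placeholder defaults standing in for an ImageNet-style setup.
  num_classes: int = 1001
  input_size: List[int] = dataclasses.field(default_factory=lambda: [224, 224, 3])
  model_id: int = 50


@dataclasses.dataclass
class ExperimentConfig:
  model: ModelConfig = dataclasses.field(default_factory=ModelConfig)
  global_batch_size: int = 4096


# Maps an experiment name to a function that builds its default config.
_EXP_REGISTRY: Dict[str, Callable[[], ExperimentConfig]] = {}


def register_config_factory(name):
  """Decorator that records a config-building function under `name`."""
  def decorator(fn):
    _EXP_REGISTRY[name] = fn
    return fn
  return decorator


def get_config_sketch(name):
  """Builds a new default config, so edits to one copy never leak into another."""
  return _EXP_REGISTRY[name]()


@register_config_factory('resnet_imagenet_sketch')
def resnet_imagenet_sketch():
  return ExperimentConfig()


# Fetch the defaults, then override only what differs for CIFAR-10.
cfg = get_config_sketch('resnet_imagenet_sketch')
cfg.model.num_classes = 10
cfg.model.input_size = [32, 32, 3]
cfg.model.model_id = 18
cfg.global_batch_size = 128
```

Building a fresh object on every lookup is the design choice that makes in-place overrides safe: a second call to `get_config_sketch` still returns the untouched defaults.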
Adjust the trainer configuration.
logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]
if 'GPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'TPU'
else:
  print('Running on CPU is slow, so only train for a few steps.')
  device = 'CPU'
if device == 'CPU':
  train_steps = 20
  exp_config.trainer.steps_per_loop = 5
else:
  train_steps = 5000
  exp_config.trainer.steps_per_loop = 100
exp_config.trainer.summary_interval = 100
exp_config.trainer.checkpoint_interval = train_steps
exp_config.trainer.validation_interval = 1000
exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100
Running on CPU is slow, so only train for a few steps.
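With the CPU settings above, the training logs later in this tutorial report `learning_rate: 0.0` at every step, which would mean the optimizer never actually moves the weights. A likely explanation, assuming Model Garden's linear warmup ramps from `warmup_learning_rate` (0 here) toward the value the post-warmup schedule has at step `warmup_steps`: on the CPU path the cosine schedule decays to zero after `decay_steps = train_steps = 20`, so its value at step 100 is already 0 and the warmup ramps from 0 to 0. The hypothetical helpers below re-implement both schedules in plain Python to check that reasoning; they are not Model Garden's schedule classes.

```python
import math


def cosine_decay(step, initial_lr, decay_steps, alpha=0.0):
  """Cosine decay in the style of Keras CosineDecay: step is clamped to decay_steps."""
  step = min(step, decay_steps)
  cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
  return initial_lr * ((1.0 - alpha) * cosine + alpha)


def warmup_then_cosine(step, initial_lr, decay_steps, warmup_steps, warmup_lr=0.0):
  """Linear warmup toward the cosine value at `warmup_steps`, then cosine decay."""
  target = cosine_decay(warmup_steps, initial_lr, decay_steps)
  if step < warmup_steps:
    return warmup_lr + step / warmup_steps * (target - warmup_lr)
  return cosine_decay(step, initial_lr, decay_steps)


# CPU branch: 20 training steps, cosine decay over 20 steps, 100 warmup steps.
cpu_lrs = [warmup_then_cosine(s, 0.1, decay_steps=20, warmup_steps=100)
           for s in range(0, 21, 5)]
print(cpu_lrs)  # [0.0, 0.0, 0.0, 0.0, 0.0]

# GPU/TPU branch: 5000 steps, so warmup ends long before the decay finishes.
gpu_lr = warmup_then_cosine(50, 0.1, decay_steps=5000, warmup_steps=100)
print(gpu_lr > 0)  # True
```

If you want the short CPU run to learn something, one option (not part of the original tutorial) is to set `warmup_steps` below `train_steps`, for example to a few steps.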
Print the modified configuration.
pprint.pprint(exp_config.as_dict())
display.Javascript("google.colab.output.setIframeHeight('300px');")
{'runtime': {'all_reduce_alg': None, 'batchnorm_spatial_persistent': False, 'dataset_num_private_threads': None, 'default_shard_dim': -1, 'distribution_strategy': 'mirrored', 'enable_xla': True, 'gpu_thread_mode': None, 'loss_scale': None, 'mixed_precision_dtype': None, 'num_cores_per_replica': 1, 'num_gpus': 0, 'num_packs': 1, 'per_gpu_thread_count': 0, 'run_eagerly': False, 'task_index': -1, 'tpu': None, 'tpu_enable_xla_dynamic_padder': None, 'use_tpu_mp_strategy': False, 'worker_hosts': None},
 'task': {'allow_image_summary': False, 'differential_privacy_config': None, 'eval_input_partition_dims': [], 'evaluation': {'precision_and_recall_thresholds': None, 'report_per_class_precision_and_recall': False, 'top_k': 5}, 'freeze_backbone': False, 'init_checkpoint': None, 'init_checkpoint_modules': 'all', 'losses': {'l2_weight_decay': 0.0001, 'label_smoothing': 0.0, 'loss_weight': 1.0, 'one_hot': True, 'soft_labels': False, 'use_binary_cross_entropy': False},
  'model': {'add_head_batch_norm': False, 'backbone': {'resnet': {'bn_trainable': True, 'depth_multiplier': 1.0, 'model_id': 18, 'replace_stem_max_pool': False, 'resnetd_shortcut': False, 'scale_stem': True, 'se_ratio': 0.0, 'stem_type': 'v0', 'stochastic_depth_drop_rate': 0.0}, 'type': 'resnet'}, 'dropout_rate': 0.0, 'input_size': [32, 32, 3], 'kernel_initializer': 'random_uniform', 'norm_activation': {'activation': 'relu', 'norm_epsilon': 1e-05, 'norm_momentum': 0.9, 'use_sync_bn': False}, 'num_classes': 10, 'output_softmax': False},
  'model_output_keys': [], 'name': None,
  'train_data': {'apply_tf_data_service_before_batching': False, 'aug_crop': True, 'aug_policy': None, 'aug_rand_hflip': True, 'aug_type': None, 'autotune_algorithm': None, 'block_length': 1, 'cache': False, 'center_crop_fraction': 0.875, 'color_jitter': 0.0, 'crop_area_range': (0.08, 1.0), 'cycle_length': 10, 'decode_jpeg_only': True, 'decoder': {'simple_decoder': {'attribute_names': [], 'mask_binarize_threshold': None, 'regenerate_source_id': False}, 'type': 'simple_decoder'}, 'deterministic': None, 'drop_remainder': True, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 'tfrecord', 'global_batch_size': 128, 'image_field_key': 'image/encoded', 'input_path': '', 'is_multilabel': False, 'is_training': True, 'label_field_key': 'image/class/label', 'mixup_and_cutmix': None, 'prefetch_buffer_size': None, 'randaug_magnitude': 10, 'random_erasing': None, 'repeated_augment': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 10000, 'tf_data_service_address': None, 'tf_data_service_job_name': None, 'tf_resize_method': 'bilinear', 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': 'cifar10', 'tfds_skip_decoding_feature': '', 'tfds_split': 'train', 'three_augment': False, 'trainer_id': None, 'weights': None},
  'train_input_partition_dims': [],
  'validation_data': {'apply_tf_data_service_before_batching': False, 'aug_crop': True, 'aug_policy': None, 'aug_rand_hflip': True, 'aug_type': None, 'autotune_algorithm': None, 'block_length': 1, 'cache': False, 'center_crop_fraction': 0.875, 'color_jitter': 0.0, 'crop_area_range': (0.08, 1.0), 'cycle_length': 10, 'decode_jpeg_only': True, 'decoder': {'simple_decoder': {'attribute_names': [], 'mask_binarize_threshold': None, 'regenerate_source_id': False}, 'type': 'simple_decoder'}, 'deterministic': None, 'drop_remainder': True, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 'tfrecord', 'global_batch_size': 128, 'image_field_key': 'image/encoded', 'input_path': '', 'is_multilabel': False, 'is_training': False, 'label_field_key': 'image/class/label', 'mixup_and_cutmix': None, 'prefetch_buffer_size': None, 'randaug_magnitude': 10, 'random_erasing': None, 'repeated_augment': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 10000, 'tf_data_service_address': None, 'tf_data_service_job_name': None, 'tf_resize_method': 'bilinear', 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': 'cifar10', 'tfds_skip_decoding_feature': '', 'tfds_split': 'test', 'three_augment': False, 'trainer_id': None, 'weights': None}},
 'trainer': {'allow_tpu_summary': False, 'best_checkpoint_eval_metric': '', 'best_checkpoint_export_subdir': '', 'best_checkpoint_metric_comp': 'higher', 'checkpoint_interval': 20, 'continuous_eval_timeout': 3600, 'eval_tf_function': True, 'eval_tf_while_loop': False, 'loss_upper_bound': 1000000.0, 'max_to_keep': 5,
  'optimizer_config': {'ema': None, 'learning_rate': {'cosine': {'alpha': 0.0, 'decay_steps': 20, 'initial_learning_rate': 0.1, 'name': 'CosineDecay', 'offset': 0}, 'type': 'cosine'}, 'optimizer': {'sgd': {'clipnorm': None, 'clipvalue': None, 'decay': 0.0, 'global_clipnorm': None, 'momentum': 0.9, 'name': 'SGD', 'nesterov': False}, 'type': 'sgd'}, 'warmup': {'linear': {'name': 'linear', 'warmup_learning_rate': 0, 'warmup_steps': 100}, 'type': 'linear'}},
  'preemption_on_demand_checkpoint': True, 'recovery_begin_steps': 0, 'recovery_max_trials': 0, 'steps_per_loop': 5, 'summary_interval': 100, 'train_steps': 20, 'train_tf_function': True, 'train_tf_while_loop': True, 'validation_interval': 1000, 'validation_steps': 78, 'validation_summary_subdir': 'validation'}}
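The printed config shows `'validation_steps': 78`. That value is the integer division of the 10,000-image test split (from `ds_info.splits['test'].num_examples`) by the batch size of 128, and the validation data uses `'drop_remainder': True`, so the final partial batch is dropped. The arithmetic as a short check:

```python
test_examples = 10_000  # size of the cifar10 'test' split reported by ds_info
batch_size = 128

# Integer division counts only full batches, consistent with drop_remainder=True.
validation_steps = test_examples // batch_size
evaluated = validation_steps * batch_size
left_out = test_examples - evaluated
print(validation_steps, evaluated, left_out)  # 78 9984 16
```

Assuming the evaluation pass reads the split once, 9,984 test images are scored and 16 are left out.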
Set up the distribution strategy.
logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]
if exp_config.runtime.mixed_precision_dtype == tf.float16:
  tf.keras.mixed_precision.set_global_policy('mixed_float16')
if 'GPU' in ''.join(logical_device_names):
  distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
  tf.tpu.experimental.initialize_tpu_system()
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
  print('Warning: this will be really slow.')
  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])
Warning: this will be really slow.
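The `global_batch_size` of 128 set earlier is the batch size summed over all replicas of the chosen strategy: `OneDeviceStrategy` has a single replica, `MirroredStrategy` typically has one per visible GPU, and a TPU strategy one per TPU core. A sketch of how the global batch is divided, using a hypothetical helper; it assumes the global batch must split evenly, which is the same rule `tf.distribute.InputContext.get_per_replica_batch_size` enforces.

```python
def per_replica_batch_size(global_batch_size, num_replicas):
  """Splits a global batch evenly across replicas, rejecting uneven splits."""
  if global_batch_size % num_replicas != 0:
    raise ValueError(
        f'global batch size {global_batch_size} is not divisible by '
        f'{num_replicas} replicas')
  return global_batch_size // num_replicas


for replicas in (1, 2, 8):
  print(replicas, per_replica_batch_size(128, replicas))
# 1 128
# 2 64
# 8 16
```

On the CPU path of this tutorial there is one replica, so each step processes all 128 images on a single device.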
Create the Task object (tfm.core.base_task.Task) from the config_definitions.TaskConfig.
The Task object has all the methods necessary for building the dataset, building the model, and running training and evaluation. These methods are driven by tfm.core.train_lib.run_experiment.
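To make this division of labor concrete, here is a heavily simplified, hypothetical analogue in plain Python: an abstract task exposes `build_model`, `build_inputs`, `train_step`, and `validation_step`, and a driver loop standing in for `run_experiment` calls them in order. The method names loosely echo Model Garden's Task, but the signatures, the driver, and the toy linear-regression task are illustrations only, not the library's API.

```python
import abc
import itertools


class TaskSketch(abc.ABC):
  """Hypothetical, simplified stand-in for a Model Garden Task."""

  @abc.abstractmethod
  def build_model(self):
    """Returns a fresh model."""

  @abc.abstractmethod
  def build_inputs(self, training):
    """Returns an iterator of (feature, label) examples."""

  @abc.abstractmethod
  def train_step(self, model, batch):
    """Updates the model on one example and returns logs."""

  @abc.abstractmethod
  def validation_step(self, model, batch):
    """Evaluates the model on one example and returns logs."""


class LinearFitTask(TaskSketch):
  """Toy task: learn w in y = 3 * x with plain SGD on squared error."""

  def __init__(self, learning_rate=0.05):
    self.learning_rate = learning_rate
    self.data = [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0)]

  def build_model(self):
    return {'w': 0.0}

  def build_inputs(self, training):
    # Training repeats the data forever; evaluation reads it once.
    return itertools.cycle(self.data) if training else iter(self.data)

  def train_step(self, model, batch):
    x, y = batch
    error = model['w'] * x - y
    model['w'] -= self.learning_rate * 2.0 * error * x  # gradient of error**2
    return {'training_loss': error ** 2}

  def validation_step(self, model, batch):
    x, y = batch
    return {'validation_loss': (model['w'] * x - y) ** 2}


def run_train_and_eval(task, train_steps):
  """Builds model and inputs, trains, then averages the validation loss."""
  model = task.build_model()
  train_iter = task.build_inputs(training=True)
  for _ in range(train_steps):
    task.train_step(model, next(train_iter))
  losses = [task.validation_step(model, batch)['validation_loss']
            for batch in task.build_inputs(training=False)]
  return model, {'validation_loss': sum(losses) / len(losses)}


toy_model, toy_logs = run_train_and_eval(LinearFitTask(), train_steps=200)
print(round(toy_model['w'], 4))  # 3.0
```

Keeping the model, data, and step logic inside the task lets one generic driver train any task, which is the role `run_experiment` plays for the real ResNet task.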
with distribution_strategy.scope():
  model_dir = tempfile.mkdtemp()
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)
# tf.keras.utils.plot_model(task.build_model(), show_shapes=True)
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  print()
  print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')
  print(f'labels.shape: {str(labels.shape):16} labels.dtype: {labels.dtype!r}')
images.shape: (128, 32, 32, 3) images.dtype: tf.float32
labels.shape: (128,) labels.dtype: tf.int32
Visualize the training data
The dataloader applies a z-score normalization using preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB), so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range.
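The two steps, z-score normalization in the loader and min-max rescaling for display, can be made concrete on three pixels in plain Python. This sketch assumes `MEAN_RGB` and `STDDEV_RGB` are the common ImageNet channel statistics scaled to the 0-255 range (check `official.vision.ops.preprocess_ops` for the exact values used). The rescaling mirrors the batch min/max approach used for plotting in this tutorial; it is good enough for display but does not exactly invert the normalization.

```python
# Assumed per-channel statistics on the 0-255 scale (ImageNet values).
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)


def normalize_pixel(rgb):
  """z-score one RGB pixel: (value - mean) / stddev, per channel."""
  return [(v - m) / s for v, m, s in zip(rgb, MEAN_RGB, STDDEV_RGB)]


def rescale_to_unit(values):
  """Min-max rescale a flat list of floats into [0, 1] for display."""
  lo, hi = min(values), max(values)
  delta = hi - lo
  return [(v - lo) / delta for v in values]


pixels = [(0, 0, 0), (255, 255, 255), (128, 64, 32)]
normalized = [c for p in pixels for c in normalize_pixel(p)]
display_ready = rescale_to_unit(normalized)
print(min(normalized) < 0, max(normalized) > 1)  # True True
print(min(display_ready), max(display_ready))    # 0.0 1.0
```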
plt.hist(images.numpy().flatten());
Use ds_info (an instance of tfds.core.DatasetInfo) to look up the text description of each class ID.
label_info = ds_info.features['label']
label_info.int2str(1)
'automobile'
Visualize a batch of the data.
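def show_batch(images, labels, predictions=None):
  """Plots the first 12 images of a batch, titled with their class names.

  If `predictions` is given, each title shows the predicted class in green
  (correct) or red (incorrect).
  """
  plt.figure(figsize=(10, 10))
  # Rescale with the batch min/max so normalized images land in [0, 1].
  img_min = images.numpy().min()
  img_max = images.numpy().max()
  delta = img_max - img_min
  for i in range(12):
    plt.subplot(6, 6, i + 1)
    plt.imshow((images[i] - img_min) / delta)
    if predictions is None:
      plt.title(label_info.int2str(labels[i]))
    else:
      if labels[i] == predictions[i]:
        color = 'g'
      else:
        color = 'r'
      plt.title(label_info.int2str(predictions[i]), color=color)
    plt.axis("off")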
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  show_batch(images, labels)
Visualize the testing data
Visualize a batch of images from the validation dataset.
for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):
  show_batch(images, labels)
Train and evaluate
model, eval_logs = tfm.core.train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode='train_and_eval',
params=exp_config,
model_dir=model_dir,
run_post_eval=True)
restoring or initializing model...
INFO:tensorflow:Customized initialization is done through the passed init_fn.
train | step: 0 | training until step 20...
train | step: 5 | steps/sec: 0.5 | output:
{'accuracy': 0.103125,
'learning_rate': 0.0,
'top_5_accuracy': 0.4828125,
'training_loss': 2.7998607}
saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-5.
train | step: 10 | steps/sec: 0.8 | output:
{'accuracy': 0.0828125,
'learning_rate': 0.0,
'top_5_accuracy': 0.4984375,
'training_loss': 2.8205295}
train | step: 15 | steps/sec: 0.8 | output:
{'accuracy': 0.0921875,
'learning_rate': 0.0,
'top_5_accuracy': 0.503125,
'training_loss': 2.8169343}
train | step: 20 | steps/sec: 0.8 | output:
{'accuracy': 0.1015625,
'learning_rate': 0.0,
'top_5_accuracy': 0.45,
'training_loss': 2.8760865}
eval | step: 20 | running 78 steps of evaluation...
eval | step: 20 | steps/sec: 24.4 | eval time: 3.2 sec | output:
{'accuracy': 0.09485176,
'steps_per_second': 24.40085348913806,
'top_5_accuracy': 0.49589342,
'validation_loss': 2.5864375}
saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-20.
eval | step: 20 | running 78 steps of evaluation...
eval | step: 20 | steps/sec: 40.1 | eval time: 1.9 sec | output:
{'accuracy': 0.09485176,
'steps_per_second': 40.14298727815298,
'top_5_accuracy': 0.49589342,
'validation_loss': 2.5864375}
# tf.keras.utils.plot_model(model, show_shapes=True)
Print the accuracy, top_5_accuracy, and validation_loss evaluation metrics.
for key, value in eval_logs.items():
  if isinstance(value, tf.Tensor):
    value = value.numpy()
  print(f'{key:20}: {value:.3f}')
accuracy            : 0.095
top_5_accuracy      : 0.496
validation_loss     : 2.586
steps_per_second    : 40.143
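Here `accuracy` is top-1 accuracy, while `top_5_accuracy` counts a prediction as correct when the true class is among the five highest-scoring classes. A minimal pure-Python sketch of top-k accuracy (a hypothetical helper, not the metric implementation Model Garden uses), plus the chance-level reference values for 10 balanced classes. The 0.095 and 0.496 reported above for the briefly trained CPU model are close to those chance values.

```python
import math


def top_k_accuracy(batch_logits, labels, k):
  """Fraction of examples whose true label is among the k highest logits."""
  hits = 0
  for logits, label in zip(batch_logits, labels):
    ranked = sorted(range(len(logits)), key=lambda c: logits[c], reverse=True)
    hits += int(label in ranked[:k])
  return hits / len(labels)


batch_logits = [
    [2.0, 1.0, 0.1],  # highest score: class 0
    [0.2, 0.3, 3.0],  # highest score: class 2, runner-up class 1
    [1.5, 1.4, 0.0],  # highest score: class 0, runner-up class 1
]
labels = [0, 1, 1]
print(top_k_accuracy(batch_logits, labels, k=1))  # 0.3333333333333333
print(top_k_accuracy(batch_logits, labels, k=2))  # 1.0

# Chance level for 10 balanced classes: top-1 = 1/10, top-5 = 5/10,
# and the cross-entropy of a uniform prediction is ln(10).
print(1 / 10, 5 / 10, round(math.log(10), 3))  # 0.1 0.5 2.303
```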
Run a batch of the processed training data through the model and view the results.
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  predictions = model.predict(images)
  predictions = tf.argmax(predictions, axis=-1)
show_batch(images, labels, tf.cast(predictions, tf.int32))
if device == 'CPU':
  plt.suptitle('The model was only trained for a few steps, so it is not expected to do well.')
4/4 [==============================] - 1s 13ms/step
Export a SavedModel
The keras.Model object returned by train_lib.run_experiment expects its inputs to be normalized the way the dataset loader normalizes them, with the same mean and standard-deviation statistics used in preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB). The export function handles those details, so you can pass tf.uint8 images to the exported model and get correct results.
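Conceptually, the exported signature bundles the preprocessing that the training pipeline applied, so callers can send raw pixels. A plain-Python sketch of that idea, assuming the ImageNet channel statistics on the 0-255 scale for `MEAN_RGB` and `STDDEV_RGB`. `make_serving_fn` and `toy_model` are hypothetical; the real export builds a TensorFlow graph around the trained network, not Python functions over lists. The `'logits'` output key matches the key read from the exported signature in the next section.

```python
# Assumed per-channel statistics on the 0-255 scale (ImageNet values).
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)


def make_serving_fn(model_fn):
  """Wraps a model that expects normalized floats so it accepts raw uint8 pixels."""
  def serve(uint8_pixels):
    normalized = [
        [(float(v) - m) / s for v, m, s in zip(pixel, MEAN_RGB, STDDEV_RGB)]
        for pixel in uint8_pixels
    ]
    return {'logits': model_fn(normalized)}
  return serve


def toy_model(normalized_pixels):
  """Hypothetical stand-in for the network: per-channel mean of its input."""
  n = len(normalized_pixels)
  return [sum(pixel[c] for pixel in normalized_pixels) / n for c in range(3)]


serve = make_serving_fn(toy_model)
# Raw pixels near the channel means normalize to values near zero.
outputs = serve([(123, 116, 103), (124, 117, 104)])
print([round(v, 3) for v in outputs['logits']])
```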
# Saving and exporting the trained model
export_saved_model_lib.export_inference_graph(
input_type='image_tensor',
batch_size=1,
input_image_size=[32, 32],
params=exp_config,
checkpoint_path=tf.train.latest_checkpoint(model_dir),
export_dir='./export/')
INFO:tensorflow:Assets written to: ./export/assets
Test the exported model.
# Importing SavedModel
imported = tf.saved_model.load('./export/')
model_fn = imported.signatures['serving_default']
Visualize the predictions.
for data in tfds.load('cifar10', split='test').batch(12).take(1):
  predictions = []
  for image in data['image']:
    # The model was exported with batch_size=1, so add a batch axis and
    # classify one image at a time.
    index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]
    predictions.append(index)
  show_batch(data['image'], data['label'], predictions)
if device == 'CPU':
  plt.suptitle('The model was only trained for a few steps, so it is not expected to do better than random.')
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-10-17 UTC.