Image Classification using ResNet

Last Updated : 23 Jul, 2025

This article walks you through the steps to implement ResNet-based image classification using Python and TensorFlow/Keras.

**Image classification** assigns an image to one of several predefined categories. **ResNet** (Residual Network) is a deep convolutional architecture that introduced **residual connections** to address the vanishing gradient problem in very deep neural networks.
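To make the idea of a residual connection concrete, here is a minimal NumPy sketch. It is a simplification under stated assumptions: the function `residual_block` and the weight matrices `w1` and `w2` are hypothetical, and a real ResNet block uses convolutions, batch normalization and an activation after the addition. The point is only that the block outputs the input plus a learned correction, so it can fall back to the identity when the correction is zero.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def residual_block(x, w1, w2):
    """Return x + F(x), where F is a small two-layer transformation."""
    fx = relu(x @ w1) @ w2   # the residual function F(x)
    return x + fx            # the skip connection adds the input back

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))              # 4 samples, 8 features each
w1 = rng.normal(size=(8, 8)) * 0.1       # hypothetical weights
w2 = rng.normal(size=(8, 8)) * 0.1

out = residual_block(x, w1, w2)

# With all-zero weights, F(x) is zero and the block is the identity mapping
identity_out = residual_block(x, np.zeros((8, 8)), np.zeros((8, 8)))

print(out.shape)                         # (4, 8)
print(np.allclose(identity_out, x))      # True
```

Because the input is added back unchanged, the gradient always has a direct path through the skip connection, which is what helps very deep stacks of such blocks train.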

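**ResNet** is a common choice for image classification for two main reasons: its residual connections make very deep networks trainable, and its ImageNet pre-trained weights can be reused for new tasks through transfer learning.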

Image Classification Using ResNet on CIFAR-10

Here’s a step-by-step guide to implement image classification using the CIFAR-10 dataset and ResNet50 in TensorFlow:

1. Import Libraries

We begin by importing the necessary libraries from TensorFlow and Keras:

```python
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
```

2. Load and Preprocess the CIFAR-10 Dataset

We load the CIFAR-10 dataset using tensorflow.keras.datasets.cifar10. Then, we normalize the pixel values of the images (by dividing by 255) to scale them to a range of 0 to 1. Lastly, we one-hot encode the labels to match the output format for categorical classification.

```python
# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Preprocess the data: scale pixel values to [0, 1]
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# One-hot encode the labels
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
```
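The two preprocessing steps above can be illustrated without downloading the dataset. The sketch below uses a small synthetic batch as a hypothetical stand-in for CIFAR-10 (the arrays `images` and `labels` are made up) and reproduces the scaling and one-hot encoding with plain NumPy.

```python
import numpy as np

# Hypothetical stand-in for a CIFAR-10 batch: 5 RGB images of 32x32 pixels
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 32, 32, 3), dtype=np.uint8)
labels = np.array([[3], [0], [9], [3], [1]])   # shape (5, 1), integer class ids

# Scale pixel values from [0, 255] to [0, 1]
scaled = images.astype("float32") / 255.0

# One-hot encode: row i has a 1 in column labels[i] and 0 elsewhere
one_hot = np.eye(10, dtype="float32")[labels.ravel()]

print(scaled.dtype, scaled.shape)      # float32 (5, 32, 32, 3)
print(one_hot.shape)                   # (5, 10)
print(int(one_hot[0].argmax()))        # 3
```

Each row of `one_hot` has the same length as the model's 10-unit softmax output, which is the format that categorical cross-entropy expects.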

3. Load ResNet50 Pre-trained on ImageNet

We use ResNet50 pre-trained on the ImageNet dataset. Setting **include_top=False** leaves out the fully connected layers (the classification head) so that we can add our own layers. We then freeze the base model so that its pre-trained weights are not updated during training.

```python
# Load ResNet50 without its classification head
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(32, 32, 3))

# Freeze the base model
base_model.trainable = False
```
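It helps to know what the headless base model produces for such small inputs. The sketch below is back-of-the-envelope arithmetic, assuming the usual overall downsampling factor of 32 for ResNet50; the helper `feature_map_size` is hypothetical and not part of Keras.

```python
import math

def feature_map_size(input_size, total_stride=32):
    """Spatial size of the last feature map, assuming an overall stride of 32."""
    return math.ceil(input_size / total_stride)

cifar_size = feature_map_size(32)      # CIFAR-10 images are 32x32
imagenet_size = feature_map_size(224)  # the usual ImageNet input size

print(cifar_size)      # 1
print(imagenet_size)   # 7
```

With 32x32 inputs, the base model therefore reduces each image to a single spatial position with 2048 channels, so very little spatial detail remains. Upscaling the images before feeding them to the network is a common way to recover some of it.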

4. Build the Classification Model

We now build the model using the pre-trained **ResNet50** as a base. We add a **GlobalAveragePooling2D** layer to collapse the spatial dimensions of the feature maps from the ResNet base, followed by a **Dense** layer with 1024 units and ReLU activation.

The final layer has 10 neurons, one for each class in the CIFAR-10 dataset, with a **softmax** activation function.

```python
# Build the classification model
model = Sequential([
    base_model,
    GlobalAveragePooling2D(),
    Dense(1024, activation='relu'),
    Dense(10, activation='softmax')
])
```
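As a rough illustration of what the pooling and softmax layers compute, here is a NumPy sketch. The feature maps and the output-layer weights are random placeholders, the 1x1x2048 feature shape is an assumption about what the frozen base returns for 32x32 inputs, and the intermediate 1024-unit layer is skipped for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical feature maps: batch of 2, spatial size 1x1, 2048 channels
features = rng.normal(size=(2, 1, 1, 2048))

# Global average pooling: average over the height and width axes
pooled = features.mean(axis=(1, 2))            # shape (2, 2048)

def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Hypothetical, randomly initialized 10-class output layer
w = rng.normal(size=(2048, 10)) * 0.01
b = np.zeros(10)
probs = softmax(pooled @ w + b)

print(pooled.shape)                            # (2, 2048)
print(probs.shape)                             # (2, 10)
print(np.allclose(probs.sum(axis=1), 1.0))     # True
```

Each row of `probs` is a probability distribution over the 10 classes, and the predicted class is the index of its largest entry.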

5. Compile the Model

We use the **Adam** optimizer with a small learning rate (0.0001) so that updates to the new layers stay small and stable, and **categorical cross-entropy** as the loss function for multi-class classification. We also track the **accuracy** metric during training.

```python
# Compile the model
model.compile(optimizer=Adam(learning_rate=0.0001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
```
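To see what the loss measures, here is a small worked example of categorical cross-entropy in NumPy. The function below is a hand-written sketch, not the Keras implementation, and the three-class prediction is made up for illustration.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Per-sample cross-entropy between one-hot targets and predicted probabilities."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)   # avoid log(0)
    return -np.sum(y_true * np.log(y_pred), axis=1)

y_true = np.array([[0.0, 0.0, 1.0]])           # the true class is index 2
y_pred = np.array([[0.1, 0.2, 0.7]])           # the model gives it probability 0.7

loss = categorical_crossentropy(y_true, y_pred)
print(loss)                                    # approximately [0.3567], i.e. -ln(0.7)
```

Only the probability assigned to the true class contributes to the loss, so the loss approaches zero as that probability approaches 1.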

6. Train the Model

We then train the model on the CIFAR-10 training data, using a batch size of 64 and 10 epochs. We also pass the test data for validation during training to monitor the model’s performance.

```python
# Train the model
model.fit(x_train, y_train,
          batch_size=64,
          epochs=10,
          validation_data=(x_test, y_test))
```
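The batch size and epoch count determine how many weight updates the run performs. The arithmetic below assumes the standard CIFAR-10 training split of 50,000 images.

```python
import math

num_train = 50_000     # size of the standard CIFAR-10 training split
batch_size = 64
epochs = 10

# The last batch of each epoch is smaller, so round up
steps_per_epoch = math.ceil(num_train / batch_size)
total_updates = steps_per_epoch * epochs

print(steps_per_epoch)   # 782
print(total_updates)     # 7820
```

The per-epoch figure is the step count that the Keras progress bar shows while training.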

7. Evaluate the Model

Once the model is trained, we evaluate it on the test data to check its accuracy.

```python
# Evaluate the model
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")
```

**Output:**

Test accuracy: 0.8741999864578247
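The accuracy metric itself is simple to reproduce: it is the fraction of samples whose highest-probability class matches the true class. A minimal NumPy sketch, using made-up predictions for four samples and three classes:

```python
import numpy as np

# Hypothetical predicted probabilities for 4 samples and 3 classes
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
])

# One-hot true labels for classes 0, 1, 0, 0
y_true = np.eye(3)[[0, 1, 0, 0]]

predicted = probs.argmax(axis=1)     # [0, 1, 2, 0]
actual = y_true.argmax(axis=1)       # [0, 1, 0, 0]
accuracy = float((predicted == actual).mean())

print(accuracy)                      # 0.75
```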

ResNet's residual connections make it possible to train very deep models, and its pre-trained weights can give strong accuracy on new tasks even with smaller datasets. In this example we froze the entire pre-trained base and trained only the new classification layers, which keeps training fast. Unfreezing some of the later ResNet layers and fine-tuning them with a low learning rate is a common next step.