How to Perform Fashion MNIST Classification Using Keras and CNN in Python?
Condition for Fashion MNIST Classification Using Keras and CNN in Python
Description: This code implements a Convolutional Neural Network (CNN) for multi-class classification on the Fashion MNIST dataset. It covers data preprocessing, CNN model creation, training, and evaluation with a classification report, confusion matrix, accuracy, precision, recall, and F1-score. The model predicts one of 10 clothing categories from 28x28 grayscale images.
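As a rough illustration of what multi-class prediction means here, the sketch below turns rows of softmax probabilities into clothing categories and scores them with sparse categorical crossentropy against integer labels. The probability values and the names `probs` and `true_labels` are made up for illustration; they are not outputs of the model in this article.

```python
import numpy as np

class_names = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

# Hypothetical softmax outputs for two images (each row sums to 1).
probs = np.array([
    [0.05, 0.01, 0.02, 0.02, 0.05, 0.01, 0.04, 0.10, 0.00, 0.70],
    [0.60, 0.00, 0.05, 0.05, 0.02, 0.00, 0.25, 0.00, 0.03, 0.00],
])
# Integer ground-truth labels (the "sparse" label format).
true_labels = np.array([9, 6])

# The predicted class is the index of the largest probability in each row.
predicted = probs.argmax(axis=1)
predicted_names = [class_names[i] for i in predicted]

# Sparse categorical crossentropy: -log of the probability assigned
# to the true class, averaged over the samples.
per_sample_loss = -np.log(probs[np.arange(len(true_labels)), true_labels])
mean_loss = per_sample_loss.mean()

print(predicted_names)
print(per_sample_loss, mean_loss)
```

The second image is a "Shirt" that the made-up probabilities assign to "T-shirt/top", so its loss term is larger than the first one.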
Step-by-Step Process
Step1: Load essential libraries like TensorFlow, NumPy, Matplotlib, and Scikit-learn for data handling, modeling, and evaluation.
Step2: Import the Fashion MNIST dataset, splitting it into training and testing sets.
Step3: Define the names of the 10 clothing categories for labeling and visualization.
Step4: Display 15 sample images from the training set with their corresponding labels.
Step5: Scale pixel values to the range [0, 1] for faster convergence during training.
Step6: Reshape image arrays to include a single channel, making them compatible with the CNN input layer.
Step7: Build a CNN architecture with Conv2D, MaxPooling2D, Flatten, and Dense layers for feature extraction and classification.
Step8: Compile the model using Adam optimizer and sparse categorical crossentropy loss for multi-class classification.
Step9: Train the CNN on the training data for 10 epochs with a batch size of 16, and validate on the test set.
Step10: Predict test labels and compute metrics such as the classification report, confusion matrix, accuracy, F1-score, precision, and recall.
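Steps 5 to 7 can be checked without training anything. The sketch below is a minimal NumPy illustration: it applies the scaling and reshaping to a small synthetic array (a stand-in for the real images, which `load_data()` returns as 28x28 `uint8` arrays), then traces the feature-map sizes and parameter counts for the layer sizes used in the sample code. The helper names `conv_params` and `dense_params` are hypothetical and exist only for this example.

```python
import numpy as np

# Synthetic stand-in for a few Fashion MNIST images (assumption for illustration).
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Step 5: scale pixel values to the range [0, 1].
scaled = images.astype("float32") / 255.0

# Step 6: add a trailing channel axis -> (samples, height, width, 1).
batch = scaled.reshape(scaled.shape[0], 28, 28, 1)
print(batch.shape)

# Step 7: trace shapes and parameter counts through the layers.
def conv_params(kernel, in_channels, out_channels):
    # One kernel x kernel filter per (input, output) channel pair, plus one bias per output.
    return kernel * kernel * in_channels * out_channels + out_channels

def dense_params(in_units, out_units):
    # One weight per (input, output) pair, plus one bias per output.
    return in_units * out_units + out_units

h, w = 28, 28
# 'same' padding keeps height and width; each 2x2 max-pool halves them.
h, w = h // 2, w // 2  # after the first pooling layer: 14 x 14
h, w = h // 2, w // 2  # after the second pooling layer: 7 x 7
flat_units = h * w * 64  # 64 channels come out of the second Conv2D

total_params = (
    conv_params(3, 1, 32)
    + conv_params(3, 32, 64)
    + dense_params(flat_units, 64)
    + dense_params(64, 32)
    + dense_params(32, 10)
)
print(flat_units, total_params)
```

The flattened vector has 3136 units and the total comes to 221994 trainable parameters, which should agree with what `model.summary()` reports for the network built in the sample code. Most of the parameters sit in the first Dense layer.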
Sample Code
# Import necessary libraries
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Model
from sklearn.metrics import (classification_report, confusion_matrix, accuracy_score,
                             f1_score, recall_score, precision_score)
import warnings
warnings.filterwarnings("ignore")
# Load the Fashion MNIST dataset
fashion_mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# Class names for Fashion MNIST
class_names = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
]
# Function to plot sample images
def plot_sample_images(images, labels, class_names, rows=3, cols=5):
    plt.figure(figsize=(10, 7))
    for i in range(rows * cols):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(images[i], cmap="gray")
        plt.title(class_names[labels[i]])
        plt.axis("off")
    plt.tight_layout()
    plt.show()
# Visualize the first 15 samples from the training set
plot_sample_images(x_train, y_train, class_names)
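# Normalize pixel values to the range [0, 1]
x_train = x_train / 255.0
x_test = x_test / 255.0
# Add a single channel dimension: (samples, 28, 28) -> (samples, 28, 28, 1)
x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[2], 1)
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], x_test.shape[2], 1)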
def CNN_model(input_shape):
    # Input layer (for image data, input_shape includes height, width, and channels)
    inputs = Input(shape=input_shape)
    # Convolutional layers
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    # Flatten the 2D feature maps into a 1D feature vector
    flatten = Flatten()(pool2)
    # Fully connected (Dense) layers
    dense1 = Dense(64, activation='relu')(flatten)
    dense2 = Dense(32, activation='relu')(dense1)
    # Output layer: one softmax probability per clothing category
    output_layer = Dense(10, activation='softmax')(dense2)
    # Build the model
    cnn_model = Model(inputs=inputs, outputs=output_layer)
    # Compile the model with Adam optimizer and sparse categorical crossentropy loss
    # (the labels are integers 0-9, not one-hot vectors)
    cnn_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return cnn_model
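# Build the model for 28x28 single-channel inputs
model = CNN_model((x_train.shape[1], x_train.shape[2], 1))
# Train for 10 epochs, validating on the test set after each epoch
model.fit(x_train, y_train, batch_size=16, epochs=10, validation_data=(x_test, y_test))
# Predict class probabilities, then take the most probable class for each image
y_pred = model.predict(x_test)
y_pred = np.argmax(y_pred, axis=1)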
# Report evaluation metrics on the test set
print("___Performance_Metrics___\n")
print('Classification_Report:\n', classification_report(y_test, y_pred))
print('Confusion_Matrix:\n', confusion_matrix(y_test, y_pred))
print('\n')
print('Accuracy_Score: ', accuracy_score(y_test, y_pred))
print('F1_Score (macro): ', f1_score(y_test, y_pred, average='macro'))
print('Recall_Score (macro): ', recall_score(y_test, y_pred, average='macro'))
print('Precision_Score (macro): ', precision_score(y_test, y_pred, average='macro'))
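
To show what the macro-averaged scores printed above measure, here is a minimal NumPy-only sketch that builds a confusion matrix and derives accuracy, precision, recall, and F1 from it. The tiny 3-class label arrays are invented for illustration, and the sketch assumes every class appears at least once in both the true and predicted labels, so it does not handle the zero-division cases that scikit-learn guards against.

```python
import numpy as np

# Made-up labels for a 3-class problem (illustration only).
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])
num_classes = 3

# Confusion matrix: rows are true classes, columns are predicted classes.
cm = np.zeros((num_classes, num_classes), dtype=int)
for t, p in zip(y_true, y_pred):
    cm[t, p] += 1

true_positives = np.diag(cm)
precision = true_positives / cm.sum(axis=0)  # correct / all predicted as that class
recall = true_positives / cm.sum(axis=1)     # correct / all truly in that class
f1 = 2 * precision * recall / (precision + recall)

# Macro averaging: the unweighted mean of the per-class scores.
accuracy = true_positives.sum() / cm.sum()
macro_precision = precision.mean()
macro_recall = recall.mean()
macro_f1 = f1.mean()

print(cm)
print(accuracy, macro_precision, macro_recall, macro_f1)
```

Because macro averaging weights every class equally, a class the model handles poorly (such as "Shirt" versus "T-shirt/top" in Fashion MNIST, which are visually similar) lowers the macro score as much as any other class would.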