Research Breakthrough Possible @S-Logix pro@slogix.in

How to Classify Handwritten Digits Using Deep Learning with Keras in Python?

Handwritten Digit Classification Using Deep Learning

Condition for Classifying Handwritten Digits Using Deep Learning with Keras in Python

  • Description:
    This code implements a Convolutional Neural Network (CNN) for classifying handwritten digits from the MNIST dataset. It visualizes sample images, preprocesses data, builds a CNN model, trains it, and evaluates the model's performance using metrics like accuracy, precision, recall, F1 score, and a confusion matrix.
Step-by-Step Process
  • Step1: Import necessary libraries such as TensorFlow, Matplotlib, Scikit-learn, and Seaborn for model building, evaluation, and visualization.
  • Step2: The MNIST dataset is loaded using TensorFlow's built-in function, providing the training and testing data for digits 0-9.
  • Step3: A function plot_sample_images is defined to visualize the first 15 images from the training set along with their corresponding labels.
  • Step4: The pixel values of the images are normalized by dividing by 255 to scale them to a range between 0 and 1.
  • Step5: Each image is reshaped from a 2D array (28x28) to a 3D array (28x28x1), adding the single grayscale channel dimension that the CNN's convolutional layers expect.
  • Step6: A Convolutional Neural Network is defined with two convolutional layers (32 and 64 filters), each followed by a max-pooling layer, then a flatten layer, two dense layers (64 and 32 units), and a 10-unit softmax output layer for multi-class classification.
  • Step7: The model is compiled with the Adam optimizer and sparse categorical cross-entropy loss function, with accuracy as the evaluation metric.
  • Step8: The CNN model is trained using the training set for 10 epochs with a batch size of 16, using the test set for validation.
  • Step9: Predictions are made on the test set using the trained model, and the predicted class labels are extracted by applying the argmax function on the model's output.
  • Step10: The model's performance is evaluated using classification metrics such as accuracy, precision, recall, and F1 score, and a confusion matrix is visualized using Seaborn.
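To make Step6 concrete, the following is a minimal sketch that traces the feature-map shapes and trainable-parameter counts through the architecture described above, using plain Python arithmetic rather than Keras. The helper names (conv2d_same, max_pool, dense) are hypothetical and exist only for this illustration; the sketch assumes 3x3 kernels with stride 1 and 'same' padding, and 2x2 non-overlapping pooling, as in the sample code.

```python
def conv2d_same(h, w, c_in, filters, k=3):
    """Shape and parameter count of a stride-1, 'same'-padded convolution."""
    params = (k * k * c_in + 1) * filters  # kernel weights plus one bias per filter
    return (h, w, filters), params         # 'same' padding keeps height and width

def max_pool(h, w, c, pool=2):
    """Shape after non-overlapping pooling (floor division on height and width)."""
    return (h // pool, w // pool, c)

def dense(n_in, units):
    """Output size and parameter count of a fully connected layer."""
    return units, (n_in + 1) * units       # weights plus one bias per unit

shape = (28, 28, 1)                        # one MNIST image after reshaping
total_params = 0

shape, p = conv2d_same(*shape, filters=32)
total_params += p
shape = max_pool(*shape)                   # 28x28 -> 14x14

shape, p = conv2d_same(*shape, filters=64)
total_params += p
shape = max_pool(*shape)                   # 14x14 -> 7x7

flat_units = shape[0] * shape[1] * shape[2]  # size of the flattened vector

n = flat_units
for units in (64, 32, 10):                 # two hidden dense layers, then the output
    n, p = dense(n, units)
    total_params += p

print("Flattened units:", flat_units)
print("Total trainable parameters:", total_params)
```

Most of the parameters sit in the first dense layer, because it connects every one of the flattened feature values to 64 units. If the model is built as in the sample code, model.summary() should report the same shapes and counts.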
Sample Code
  • import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
    from tensorflow.keras.models import Model
    from sklearn.metrics import (classification_report, confusion_matrix, accuracy_score,
                                 f1_score, recall_score, precision_score)

    import warnings
    warnings.filterwarnings("ignore")

    import seaborn as sns

    # Load MNIST dataset
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Class names for MNIST (digits 0-9)
    class_names = [str(i) for i in range(10)] # Digits from 0 to 9

    # Print the shape of the dataset
    print("Training data shape:", x_train.shape)
    print("Testing data shape:", x_test.shape)

    # Optionally, print the class names
    print("Class names:", class_names)

    # Function to plot sample images
    def plot_sample_images(images, labels, class_names, rows=3, cols=5):
        plt.figure(figsize=(10, 7))
        for i in range(rows * cols):
            plt.subplot(rows, cols, i + 1)
            plt.imshow(images[i], cmap="gray")
            plt.title(class_names[labels[i]])
            plt.axis("off")
        plt.tight_layout()
        plt.show()

    # Visualize the first 15 samples from the training set
    plot_sample_images(x_train, y_train, class_names)

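    # Normalize pixel values from [0, 255] to [0, 1]
    x_train = x_train / 255.0
    x_test = x_test / 255.0
    # Add a channel dimension: (samples, 28, 28) -> (samples, 28, 28, 1)
    x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[2], 1)
    x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], x_test.shape[2], 1)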

    def CNN_model(input_shape):
        # Input layer (for image data, input_shape includes height, width, and channels)
        inputs = Input(shape=input_shape)
        # Convolutional layers, each followed by 2x2 max pooling
        conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
        conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
        # Flatten the 2D feature maps into a 1D feature vector
        flatten = Flatten()(pool2)
        # Fully connected (Dense) layers
        dense1 = Dense(64, activation='relu')(flatten)
        dense2 = Dense(32, activation='relu')(dense1)
        # Softmax output layer with one unit per digit class
        output_layer = Dense(10, activation='softmax')(dense2)
        # Build the model
        cnn_model = Model(inputs=inputs, outputs=output_layer)
        # Compile with the Adam optimizer and sparse categorical cross-entropy loss
        # (the labels are integer class indices, not one-hot vectors)
        cnn_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        return cnn_model
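    # Build the model for 28x28x1 inputs
    model = CNN_model((x_train.shape[1], x_train.shape[2], 1))
    # Train for 10 epochs; the test set is used here as validation data
    model.fit(x_train, y_train, batch_size=16, epochs=10, validation_data=(x_test, y_test))
    # Predict class probabilities, then take the most probable class for each image
    y_pred = model.predict(x_test)
    y_pred = np.argmax(y_pred, axis=1)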

    print("___Performance_Metrics___\n")
    print('Classification_Report:\n',classification_report(y_test, y_pred))
    print('Confusion_Matrix:\n',confusion_matrix(y_test, y_pred))

    print('\n')
    print('Accuracy_Score: ',accuracy_score(y_test, y_pred))
    print('F1_Score (macro): ', f1_score(y_test, y_pred, average='macro'))
    print('Recall_Score (macro): ', recall_score(y_test, y_pred, average='macro'))
    print('Precision_Score (macro): ', precision_score(y_test, y_pred, average='macro'))

    # Compute the confusion matrix
    cm = confusion_matrix(y_test, y_pred)

    # Plot the confusion matrix as a Seaborn heatmap, labelling both axes with the digit classes
    plt.figure(figsize=(6,6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()
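To show how the macro-averaged scores printed above relate to the confusion matrix, here is a small sketch that recomputes accuracy, precision, recall, and F1 from a matrix by hand. The function name macro_metrics and the 3-class toy matrix are made up for illustration (they are not MNIST results); the matrix is assumed to follow the scikit-learn convention, with rows as true labels and columns as predicted labels.

```python
def macro_metrics(cm):
    """Accuracy and macro-averaged precision, recall, and F1 from a confusion matrix.

    cm[i][j] is the number of samples whose true class is i and predicted class is j.
    """
    n = len(cm)
    total = sum(sum(row) for row in cm)
    precisions, recalls, f1s = [], [], []
    for k in range(n):
        tp = cm[k][k]                                  # correctly predicted as class k
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum: all predicted as k
        actual_k = sum(cm[k])                          # row sum: all truly of class k
        p = tp / predicted_k if predicted_k else 0.0
        r = tp / actual_k if actual_k else 0.0
        f = 2 * p * r / (p + r) if (p + r) else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f)
    return {
        "accuracy": sum(cm[k][k] for k in range(n)) / total,
        "precision": sum(precisions) / n,  # macro average: unweighted mean over classes
        "recall": sum(recalls) / n,
        "f1": sum(f1s) / n,
    }

# Toy 3-class confusion matrix (hypothetical values)
toy_cm = [
    [5, 1, 0],
    [0, 4, 1],
    [1, 0, 3],
]
scores = macro_metrics(toy_cm)
for name, value in scores.items():
    print(f"{name}: {value:.4f}")
```

Macro averaging gives every class equal weight regardless of how many samples it has, which is why it can differ from accuracy when some digits are misclassified more often than others.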
Screenshots
  • Hand Written Digits