Exercise: Create and Train a CNN for Classification

Understand how to build and train a convolutional neural network to classify 28x28 grayscale images of digits. This lesson guides you through the creation of the network, training and validation processes, and testing your model on handwritten digit data.

Problem statement

In this lesson, we’ll create and train a CNN to classify 28x28 grayscale images of handwritten digits. We’ll use the MNIST training set of 60,000 images. Let’s display a few examples from it.

Python
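import os
import random

import matplotlib.pyplot as plt
import torchvision

# Create a grid of subplots to hold the sample images
number_of_rows = 3
number_of_columns = 4
fig, axs = plt.subplots(number_of_rows, number_of_columns)

# Load the MNIST training dataset (downloaded to ./ on first use)
transform = torchvision.transforms.ToTensor()
mnist_dataset = torchvision.datasets.MNIST("./", train=True, download=True, transform=transform)

for row in range(number_of_rows):
    for col in range(number_of_columns):
        index = random.randint(0, len(mnist_dataset) - 1)  # Choose a random index from the dataset
        img_tsr, class_ndx = mnist_dataset[index]  # Get the image tensor and the target class index
        # Drop the channel dimension, (1, 28, 28) -> (28, 28), and show the image in grayscale
        axs[row, col].imshow(img_tsr.squeeze(0).numpy(), cmap="gray")
        axs[row, col].set_title(str(class_ndx))  # Label each image with its target digit

plt.tight_layout()
os.makedirs("./output", exist_ok=True)  # Make sure the output directory exists
plt.savefig("./output/0_images_sample.png")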

Classifying these images with hand-written rules, as in traditional computer vision programming, would be difficult. We can identify each image as one of ten digits at a glance, but we would struggle to explain how we reached the answer. That makes this problem a good fit for machine learning.
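Before opening the notebook, it can help to see how a 28x28 input shrinks as it passes through convolution and pooling layers. The sketch below applies the standard output-size formula, floor((size + 2 * padding - kernel) / stride) + 1, to a hypothetical stack of layers. The kernel sizes, strides, and layer order are assumptions chosen for illustration, not the architecture used in the exercise.

```python
def conv_output_size(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling layer (square input, no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


# Hypothetical layer stack; these hyperparameters are illustrative assumptions.
layers = [
    ("conv 5x5", {"kernel": 5}),
    ("maxpool 2x2", {"kernel": 2, "stride": 2}),
    ("conv 5x5", {"kernel": 5}),
    ("maxpool 2x2", {"kernel": 2, "stride": 2}),
]

size = 28  # MNIST images are 28x28
sizes = [size]
for name, params in layers:
    size = conv_output_size(size, **params)
    sizes.append(size)
    print(f"{name}: {size}x{size}")
```

Tracking the spatial size this way is useful when sizing the first fully connected layer, whose input length is the final spatial size squared times the number of channels in the last convolution.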

Launch the Jupyter Notebook below and follow the instructions.

Instructions:

  • The comment lines ...