
Exercise: Create and Train a CNN for Classification


Challenge yourself to create and train a CNN to classify handwritten digits from the MNIST dataset.

Problem statement

In this lesson, we’ll create and train a CNN to classify 28x28 grayscale images of handwritten digits. We’ll use the MNIST training set of 60,000 images (MNIST also provides a separate test set of 10,000 images). Let’s display a few random examples from the training set.

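```python
import os
import random

import matplotlib.pyplot as plt
import torchvision

# Create a grid of subplots: 3 rows x 4 columns
number_of_rows = 3
number_of_columns = 4
fig, axs = plt.subplots(number_of_rows, number_of_columns)

# Load the MNIST training set; ToTensor() turns each 28x28 image into a 1x28x28 float tensor in [0, 1]
transform = torchvision.transforms.ToTensor()
mnist_dataset = torchvision.datasets.MNIST("./", train=True, download=True, transform=transform)

for row in range(number_of_rows):
    for col in range(number_of_columns):
        index = random.randint(0, len(mnist_dataset) - 1)  # Choose a random index from the whole dataset
        img_tsr, class_ndx = mnist_dataset[index]  # Get the image tensor and the target class index
        # Drop the channel dimension (1x28x28 -> 28x28) so imshow treats it as a single grayscale image
        axs[row, col].imshow(img_tsr.squeeze(0).numpy(), cmap="gray")
        axs[row, col].set_title(f"label: {class_ndx}")
        axs[row, col].axis("off")

plt.tight_layout()
os.makedirs("./output", exist_ok=True)  # savefig fails if the output folder does not exist
plt.savefig("./output/0_images_sample.png")
```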

Attempting to classify these images with traditional, hand-coded computer vision rules would be difficult. This is one of those cases where we can visually identify each image as one of 10 digits without effort, yet we would have a hard time explaining how we got to the correct answer. That makes it a perfect application for machine learning. ...
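As a possible starting point for the exercise, here is a minimal sketch of one CNN architecture and a single training step, assuming PyTorch (`torch.nn`) alongside torchvision. The class name `DigitCNN`, the helper `train_step`, the layer sizes, and the Adam learning rate are our own illustrative choices, not the lesson's solution. To keep it self-contained, it trains on a random batch shaped like the output of `ToTensor()` (batch x 1 x 28 x 28, values in [0, 1]) instead of downloading MNIST.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Hypothetical architecture: two conv/pool stages followed by a linear classifier
class DigitCNN(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        # 1x28x28 -> 16x28x28 (padding=1 keeps the spatial size); pooling halves it to 16x14x14
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        # 16x14x14 -> 32x14x14; pooling halves it to 32x7x7
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)  # (batch, 32*7*7)
        return self.fc(x)  # raw logits; cross_entropy applies log-softmax internally


def train_step(model: nn.Module, optimizer: torch.optim.Optimizer,
               images: torch.Tensor, targets: torch.Tensor) -> float:
    """Run one gradient-descent step on a batch and return the batch loss."""
    model.train()
    optimizer.zero_grad()
    loss = F.cross_entropy(model(images), targets)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = DigitCNN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Synthetic stand-in batch with the same shape as MNIST images after ToTensor()
images = torch.rand(8, 1, 28, 28)
targets = torch.randint(0, 10, (8,))

logits = model(images)
print(logits.shape)  # torch.Size([8, 10]): one score per digit class for each image
loss = train_step(model, optimizer, images, targets)
print(f"loss after one step: {loss:.4f}")
```

To train on the real data, one option is to wrap `mnist_dataset` in a `torch.utils.data.DataLoader` and call `train_step` on every batch it yields, repeating for several epochs.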