Exercise: Create and Train a CNN for Classification
Challenge yourself to create and train a CNN to classify handwritten digits from the MNIST dataset.
Problem statement
In this lesson, we’ll create and train a CNN to classify 28x28 grayscale images of handwritten digits into one of ten classes (0 through 9). We’ll use the MNIST training set, which contains 60,000 labeled images. Let’s display a few random examples from the training dataset.
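import os
import random

import matplotlib.pyplot as plt
import torchvision

# Create a grid of subplots for the sample images
number_of_rows = 3
number_of_columns = 4
fig, axs = plt.subplots(number_of_rows, number_of_columns)

# Load the MNIST training set; ToTensor() converts each image to a 1x28x28 float tensor in [0, 1]
transform = torchvision.transforms.ToTensor()
mnist_dataset = torchvision.datasets.MNIST("./", train=True, download=True, transform=transform)

for row in range(number_of_rows):
    for col in range(number_of_columns):
        index = random.randint(0, len(mnist_dataset) - 1)  # Choose a random index from the dataset
        img_tsr, class_ndx = mnist_dataset[index]  # Get the image tensor and the target class index
        axs[row, col].imshow(img_tsr.squeeze(0).numpy(), cmap="gray")  # Drop the channel dimension: 1x28x28 -> 28x28
        axs[row, col].set_title(str(class_ndx))  # Show the digit label above each image
        axs[row, col].axis("off")

plt.tight_layout()
os.makedirs("./output", exist_ok=True)  # Make sure the output directory exists before saving
plt.savefig("./output/0_images_sample.png")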
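Each sample returned by the dataset above is an `(img_tsr, class_ndx)` pair, where `ToTensor()` has turned the 28x28 grayscale image into a float tensor of shape 1x28x28. As a possible starting point for the exercise, here is a minimal sketch of a CNN that accepts batches of such tensors, produces one score per digit class, and takes a single training step. The class name `SmallCNN`, the layer sizes, the SGD learning rate, and the use of random tensors in place of real MNIST batches are all illustrative assumptions, not the lesson's reference solution.

```python
import torch
import torch.nn as nn


class SmallCNN(nn.Module):
    """Hypothetical minimal CNN for 1x28x28 grayscale digits (architecture is an illustrative assumption)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),   # (N, 1, 28, 28) -> (N, 16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> (N, 16, 14, 14)
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # -> (N, 32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> (N, 32, 7, 7)
        )
        # One raw score (logit) per digit class
        self.classifier = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)  # (N, 32, 7, 7) -> (N, 1568)
        return self.classifier(x)


model = SmallCNN()

# Random stand-ins for a batch of 8 MNIST images and their labels (assumption: not real data)
images = torch.rand(8, 1, 28, 28)
targets = torch.randint(0, 10, (8,))

logits = model(images)
print(logits.shape)  # torch.Size([8, 10])

# One optimization step with cross-entropy loss, which expects raw logits and integer class indices
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

optimizer.zero_grad()
loss = loss_fn(logits, targets)
loss.backward()
optimizer.step()
print(f"loss for this batch: {loss.item():.4f}")
```

In a full training loop, the random tensors would be replaced by batches drawn from `mnist_dataset` (for example through a `torch.utils.data.DataLoader`), and this step would repeat over many batches and epochs.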
Classifying these images with traditional, hand-coded computer vision techniques would be difficult. We can recognize each image as one of the ten digits without effort, yet we would have a hard time explaining the rules we used to reach the correct answer. That makes this task a perfect application for machine learning. ...