MNIST CNN

Learn how to create an MNIST Convolutional Neural Network.

To build our familiarity with convolutions in a neural network, let’s make an MNIST classifier before we try making a GAN that uses them.

We start by making a copy of our previous MNIST classifier given here.

📝 We’ll only need to change the definition of the classifier neural network. The rest of the code for loading the data, viewing images, training the network, and checking classification performance shouldn’t need to change much.

That neural network had an input layer of 784 nodes, fully connected to a middle layer of 200 nodes, which itself was fully connected to an output layer of 10 nodes. The middle layer had a LeakyReLU activation and a layer normalization applied after it. The output layer simply had a sigmoid activation applied. That network achieved a really good 97% accuracy on the MNIST test data.
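As a reminder, the core of that fully connected classifier can be sketched roughly as follows. This is a minimal sketch based on the description above; the use of nn.Sequential and the LeakyReLU slope of 0.02 are illustrative assumptions, so your copy of the code may differ in the details.

```python
# A minimal sketch of the earlier fully connected MNIST classifier.
# The Sequential layout and the LeakyReLU slope are illustrative choices,
# not necessarily the exact values used in the course code.
import torch
import torch.nn as nn

fc_model = nn.Sequential(
    nn.Linear(784, 200),   # input layer: 784 pixel values -> 200 middle nodes
    nn.LeakyReLU(0.02),    # LeakyReLU activation on the middle layer
    nn.LayerNorm(200),     # layer normalization applied after the activation
    nn.Linear(200, 10),    # middle layer -> 10 output nodes (digits 0-9)
    nn.Sigmoid()           # sigmoid activation on the output layer
)

# a flattened 1-dimensional image of 784 pixel values
x = torch.rand(784)
print(fc_model(x).shape)   # torch.Size([10])
```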

Working with convolution filters

We now need to think about how to replace this with convolution filters. The first thing that strikes us is that convolutions work on 2-dimensional images, whereas we were feeding this network a simple 1-dimensional list of pixel values. A quick and easy fix might seem to be reshaping the image_data_tensor to have shape (28, 28) whenever we pass it to the network.

Defining a CNN architecture

What we actually have to do is use a 4-dimensional tensor, because PyTorch’s convolution layers expect data tensors with four dimensions: (batch size, channels, height, width). We’re using a batch size of 1, and the MNIST images are monochrome so they have only 1 channel, which means our MNIST data needs to be shaped as (1, 1, 28, 28). We can easily do this using the view() function.
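As a quick illustration, here is how view() turns a flattened 784-value image into the (1, 1, 28, 28) shape described above. The random tensor below is just a stand-in for one flattened MNIST image.

```python
# A small sketch of reshaping a flattened MNIST image into the 4-dimensional
# (batch size, channels, height, width) form that PyTorch convolutions expect.
import torch

image_data_tensor = torch.rand(784)              # stand-in for one flattened MNIST image
image_4d = image_data_tensor.view(1, 1, 28, 28)  # batch size 1, 1 channel, 28x28 pixels
print(image_4d.shape)                            # torch.Size([1, 1, 28, 28])
```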

Have a look at the following code for a convolutional neural network.
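One possible layout is sketched below; the filter counts, kernel sizes, and strides are illustrative assumptions rather than the exact values used in the course code, but the overall shape (convolution layers followed by a fully connected layer mapping to 10 outputs) is the same idea.

```python
# A hedged sketch of a convolutional MNIST classifier. The number of filters,
# kernel sizes, and strides below are illustrative choices only.
import torch
import torch.nn as nn

cnn_model = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=5, stride=2),   # (1, 1, 28, 28) -> (1, 10, 12, 12)
    nn.LeakyReLU(0.02),
    nn.BatchNorm2d(10),

    nn.Conv2d(10, 10, kernel_size=3, stride=2),  # (1, 10, 12, 12) -> (1, 10, 5, 5)
    nn.LeakyReLU(0.02),
    nn.BatchNorm2d(10),

    nn.Flatten(),                                # (1, 10, 5, 5) -> (1, 250)
    nn.Linear(250, 10),                          # 250 features -> 10 digit classes
    nn.Sigmoid()
)

# pass a single reshaped image through the network
image_4d = torch.rand(1, 1, 28, 28)
print(cnn_model(image_4d).shape)                 # torch.Size([1, 10])
```

Notice that the input is the 4-dimensional (1, 1, 28, 28) tensor we prepared with view(), and the output is still 10 values, one per digit class, just as in the fully connected classifier.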
