Exercise: Image Classification on Fashion-MNIST with CNN

This is our first example of using a CNN for a real-world machine learning task: classifying images. We start with images rather than an NLP task because applying CNNs to NLP tasks (for example, sentence classification) is not straightforward and involves several tricks. CNNs were originally designed to cope with image data, so let’s start there and then work our way toward applying CNNs to NLP tasks.

About the data

In this exercise, we’ll use a well-known dataset in the computer vision community: the Fashion-MNIST dataset. Fashion-MNIST was inspired by the famous MNIST dataset, a database of labeled images of handwritten digits from 0 to 9 (i.e., 10 classes). The MNIST classification task is simple enough that the best reported test accuracy is just shy of 100%; at the time of writing, the popular research benchmarking site paperswithcode.com lists a test accuracy of 99.87%. With so little room left for improvement on MNIST, Fashion-MNIST was created as a more challenging alternative.

Fashion-MNIST consists of images of clothing items. Our task is to classify each item into one of 10 categories (e.g., dress, T-shirt). The dataset comes in two parts: a training set and a test set. We’ll train on the training set and evaluate the performance of our model on the unseen test set. We’ll further split the training set into a training set and a validation set, and use the validation set to monitor the model’s performance continuously during training. We’ll discuss the details later, but we’ll see that we can reach approximately 88% test accuracy without any special regularization or tricks.
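The training/validation split described above can be sketched with plain NumPy. This is only an illustration under stated assumptions: the helper name train_valid_split, the 10% validation fraction, and the fixed seed are hypothetical choices made here, not values given in this lesson.

```python
import numpy as np


def train_valid_split(images, labels, valid_fraction=0.1, seed=42):
    """Shuffle the data once, then hold out a fraction of it for validation.

    valid_fraction and seed are illustrative defaults (assumptions),
    not values prescribed by the lesson.
    """
    if len(images) != len(labels):
        raise ValueError("images and labels must have the same length")

    # A seeded generator makes the split reproducible across runs.
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(images))

    # The first n_valid shuffled indices become the validation set.
    n_valid = int(round(len(images) * valid_fraction))
    valid_idx, train_idx = order[:n_valid], order[n_valid:]

    return (images[train_idx], labels[train_idx]), (images[valid_idx], labels[valid_idx])
```

Applied to NumPy arrays of images and labels, this returns two (images, labels) pairs. The test set is left untouched so that it remains unseen until the final evaluation.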

Downloading and exploring the data

The very first task is to download and explore the data. To download the data, we’ll use the tf.keras.datasets module, which provides several datasets that can be downloaded conveniently through TensorFlow. You can explore other datasets on the TensorFlow website. To download Fashion-MNIST, simply call the tf.keras.datasets.fashion_mnist.load_data() function.
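A minimal sketch of the download-and-explore step, assuming TensorFlow and NumPy are installed and the machine has network access on the first call. The call tf.keras.datasets.fashion_mnist.load_data() is the Keras API; the wrapper load_fashion_mnist and the helper summarize are hypothetical names introduced here for illustration.

```python
import numpy as np

# Class names in label order (0-9), as documented for Fashion-MNIST.
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def load_fashion_mnist():
    """Download (on first call) and return Fashion-MNIST as NumPy arrays.

    Returns ((train_images, train_labels), (test_images, test_labels)).
    """
    # Imported here so that summarize() below can be used without TensorFlow.
    import tensorflow as tf

    return tf.keras.datasets.fashion_mnist.load_data()


def summarize(images, labels):
    """Collect basic facts about an image set for a first look at the data."""
    counts = np.bincount(labels, minlength=len(CLASS_NAMES))
    return {
        "shape": images.shape,
        "dtype": str(images.dtype),
        "pixel_range": (int(images.min()), int(images.max())),
        "per_class": dict(zip(CLASS_NAMES, counts.tolist())),
    }
```

A typical call would be (train_images, train_labels), (test_images, test_labels) = load_fashion_mnist(), followed by summarize(train_images, train_labels). The training set should contain 60,000 grayscale images of 28×28 pixels and the test set 10,000.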
