Image Classification with JAX and Flax

Learn how to create models that classify images using JAX and Flax.

Flax is a neural network library for JAX. JAX is a Python library for high-performance numerical computing that is widely used in machine learning research. Its API is similar to NumPy's, which makes it easy to adopt. JAX also includes other features that are useful in machine learning research, including:

  • Automatic differentiation: JAX supports forward- and reverse-mode automatic differentiation of numerical functions with functions such as jacrev, grad, and hessian.
  • Vectorization: JAX supports automatic vectorization via the vmap function. The pmap function similarly parallelizes a computation across multiple devices, which helps with large-scale data processing.
  • JIT compilation: JAX uses XLA for just-in-time (JIT) compilation and execution of code on GPUs and TPUs.
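
The three features above can be sketched on a toy function. This is a minimal illustration rather than part of the lesson's model: the function sum_of_squares and the input values are hypothetical and chosen only to keep the arithmetic easy to check.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function used only for illustration: f(x) = sum(x ** 2)
def sum_of_squares(x):
    return jnp.sum(x ** 2)

# Automatic differentiation: grad returns a new function that computes df/dx
grad_fn = jax.grad(sum_of_squares)

# Vectorization: vmap applies the function to each row of a batch
batched_fn = jax.vmap(sum_of_squares)

# JIT compilation: jit compiles the gradient function with XLA
fast_grad_fn = jax.jit(grad_fn)

x = jnp.array([1.0, 2.0, 3.0])
batch = jnp.stack([x, 2 * x])

gradient = fast_grad_fn(x)        # the gradient of sum(x ** 2) is 2 * x
batch_values = batched_fn(batch)  # one scalar per row of the batch

print(gradient)
print(batch_values)
```

Note that grad, vmap, and jit all take a function and return a new function, so they can be composed freely, as in jax.jit(jax.grad(...)) above.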

In this lesson, we will use JAX and Flax to build a simple convolutional neural network.

Loading the dataset

We’ll use the Cats and Dogs dataset we used in previous lessons. Let’s unzip the zip file containing the dataset.

import zipfile

with zipfile.ZipFile('../train.zip', 'r') as zip_ref:
    zip_ref.extractall('.')

In the code above:

  • Line 1: We import the zipfile library.

  • Lines 3–4: We open the zip file in read mode with the ZipFile class of the zipfile module and bind it to zip_ref. The with statement closes the file automatically when the block exits. We call the extractall() method to extract the contents of the zip file into the current directory.
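
To try the same pattern without the real archive, the sketch below builds a small throwaway zip file first. The archive name sample.zip and the member names are hypothetical placeholders, not the actual contents of train.zip. It also calls namelist() to inspect the archive before extracting it.

```python
import os
import tempfile
import zipfile

# Build a tiny throwaway archive so the sketch runs without train.zip.
# The archive name and member names are hypothetical placeholders.
work_dir = tempfile.mkdtemp()
archive_path = os.path.join(work_dir, "sample.zip")
with zipfile.ZipFile(archive_path, "w") as zip_ref:
    zip_ref.writestr("train/cat.0.jpg", b"placeholder bytes")
    zip_ref.writestr("train/dog.0.jpg", b"placeholder bytes")

# Open the archive in read mode, list its members, then extract them.
with zipfile.ZipFile(archive_path, "r") as zip_ref:
    member_names = zip_ref.namelist()
    zip_ref.extractall(work_dir)

extracted = sorted(os.listdir(os.path.join(work_dir, "train")))
print(member_names)
print(extracted)
```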

Flax doesn’t ship with any data loading tools. We can use the data loaders from PyTorch or TensorFlow. In this case, let’s load the data using PyTorch. The first step is to define the dataset class.

import os
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

class CatsDogsDataset(Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        y_label = torch.tensor(float(self.annotations.iloc[index, 1]))

        if self.transform is not None:
            img = self.transform(img)

        return (img, y_label)

In the code above:

  • Lines 1–5: We import the required libraries: os, pandas, torch, the Image module from PIL, and the Dataset class from torch.utils.data.

  • Lines 7–24: We define a dataset class, CatsDogsDataset, to load the images and labels. Inside this class:

    • Lines 8–11: We define the constructor, __init__(), that receives the root directory and the annotation file as parameters. It also receives an optional parameter, transform, that holds a transformation to apply to the images. The constructor initializes the instance variables.

    • Lines 13–14: We define the __len__() method that returns the number of elements in the dataset.

    • Lines 16–24: We define the __getitem__() method that receives one index parameter. We read img_id and y_label from the annotation file, open the image as RGB, and store it in img. We apply the transformation (if any) and return the image and its label.

Next, we create a pandas DataFrame that maps each image file name to its category label.

import os

train_df = pd.DataFrame(columns=["img_path", "label"])
train_df["img_path"] = os.listdir("train/")

for idx, file_name in enumerate(train_df["img_path"]):
    if "cat" in file_name:
        train_df.loc[idx, "label"] = 0
    if "dog" in file_name:
        train_df.loc[idx, "label"] = 1

print(train_df)
train_df.to_csv(r'train_csv.csv', index=False, header=True)

In the code above:

  • Line 1: We import the os library.

  • Line 3: We create a pandas DataFrame, train_df, with two columns: img_path to record the file name and label to store the label.

  • Line 4: We call the listdir() function of the os module to list all the file names in the train directory and store them in the img_path column of train_df.

  • Lines 6–10: We use a for loop to iterate through each file name in the train directory. We assign label 0 if the file name contains cat and label 1 if it contains dog.

  • Lines 12–13: Lastly, we print the DataFrame train_df and call the to_csv() method of train_df to convert it into a CSV file.
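
The same labels can also be computed without an explicit loop. The sketch below is an alternative formulation, not the lesson's code: it uses a hypothetical list of file names in place of os.listdir("train/") and assumes every file name contains either cat or dog.

```python
import pandas as pd

# Hypothetical file names standing in for os.listdir("train/")
file_names = ["cat.0.jpg", "dog.0.jpg", "cat.1.jpg", "dog.1.jpg"]

labels_df = pd.DataFrame({"img_path": file_names})

# str.contains gives True for names containing "dog"; casting to int
# yields 1 for dogs and 0 for everything else (assumed here to be cats).
labels_df["label"] = labels_df["img_path"].str.contains("dog").astype(int)

label_list = labels_df["label"].tolist()
print(labels_df)
```

Assigning the whole column at once avoids per-row writes and leaves the label column with an integer dtype.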

Next, we define a function that stacks the samples of a batch and returns them as NumPy arrays, because JAX works with NumPy-compatible arrays rather than PyTorch tensors.

import numpy as np

def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels

In the code above:

  • Line 1: We import the NumPy library as ...