Image Classification with JAX and Flax
Learn how to create models that classify images using JAX and Flax.
Flax is a neural network library for JAX. JAX is a Python library for high-performance numerical computing in machine learning research. Its API closely mirrors NumPy's, which makes it easy to adopt. JAX also provides several other features that are useful for machine learning research, including:
- Automatic differentiation: JAX supports forward- and reverse-mode automatic differentiation of numerical functions with functions such as jacrev, grad, and hessian.
- Vectorization: JAX supports automatic vectorization via the vmap function. It also makes it easy to parallelize large-scale data processing across devices via the pmap function.
- JIT compilation: JAX uses XLA for just-in-time (JIT) compilation and execution of code on GPUs and TPUs.
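The features listed above can be tried on a toy function. The following is a minimal sketch, assuming JAX is installed. The functions sum_of_squares and squared_norm are hypothetical examples chosen for illustration and are not part of the lesson's model.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    # Toy scalar function (an assumption for this sketch): f(x) = sum(x_i ** 2)
    return jnp.sum(x ** 2)


# Automatic differentiation: the gradient of f is 2x, and its Hessian is 2I.
grad_fn = jax.grad(sum_of_squares)
hessian_fn = jax.hessian(sum_of_squares)


def squared_norm(v):
    # Operates on a single vector.
    return jnp.dot(v, v)


# Vectorization: vmap maps squared_norm over the leading (batch) axis.
batched_norm = jax.vmap(squared_norm)

# JIT compilation: jit compiles the batched function with XLA.
fast_batched_norm = jax.jit(batched_norm)

x = jnp.array([1.0, 2.0, 3.0])
gradient = grad_fn(x)
hessian = hessian_fn(x)

batch = jnp.arange(6.0).reshape(2, 3)  # two vectors of length 3
norms = fast_batched_norm(batch)
```

pmap is left out of this sketch because it is only useful when more than one device is available.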
In this lesson, we will use JAX and Flax to build a simple convolutional neural network.
Loading the dataset
We’ll use the Cats and Dogs dataset we used in previous lessons. Let’s unzip the zip file containing the dataset.
    import zipfile

    with zipfile.ZipFile('../train.zip', 'r') as zip_ref:
        zip_ref.extractall('.')
In the code above:
- Line 1: We import the zipfile module.
- Lines 3–4: We call the ZipFile() constructor of the zipfile module to open the zip file in read mode as zip_ref. We use the with statement so that the file is closed automatically when the block finishes. We call the extractall() method to extract the contents of the zip file into the current directory.
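To experiment with this pattern without the real dataset, it can be wrapped in a small helper and run on a throwaway archive. This is a minimal sketch: the helper name extract_archive and the member names cat.0.jpg and dog.0.jpg are assumptions made for illustration and may not match the actual contents of train.zip.

```python
import os
import tempfile
import zipfile


def extract_archive(zip_path, dest_dir):
    """Extract zip_path into dest_dir and return the archive's member names."""
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        names = zip_ref.namelist()    # inspect the contents before extracting
        zip_ref.extractall(dest_dir)  # the file is closed when the block exits
    return names


# Build a tiny stand-in archive in a temporary directory.
work_dir = tempfile.mkdtemp()
demo_zip = os.path.join(work_dir, "demo.zip")
with zipfile.ZipFile(demo_zip, "w") as zf:
    zf.writestr("train/cat.0.jpg", b"placeholder bytes")  # hypothetical file name
    zf.writestr("train/dog.0.jpg", b"placeholder bytes")  # hypothetical file name

extracted = extract_archive(demo_zip, work_dir)
```

Listing the member names first is a cheap way to confirm the directory layout that the later steps expect.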
Flax doesn’t ship with any data loading tools. We can use the data loaders from PyTorch or TensorFlow. In this case, let’s load the data using PyTorch. The first step is to define the dataset class.
    import os
    import pandas as pd
    import torch
    from PIL import Image
    from torch.utils.data import Dataset

    class CatsDogsDataset(Dataset):
        def __init__(self, root_dir, annotation_file, transform=None):
            self.root_dir = root_dir
            self.annotations = pd.read_csv(annotation_file)  # columns: file name, label
            self.transform = transform

        def __len__(self):
            return len(self.annotations)

        def __getitem__(self, index):
            img_id = self.annotations.iloc[index, 0]  # file name of the image
            img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
            y_label = torch.tensor(float(self.annotations.iloc[index, 1]))

            if self.transform is not None:
                img = self.transform(img)

            return (img, y_label)
In the code above:
- Lines 1–5: We import the required libraries: os, pandas, torch, the Image module from PIL, and the Dataset class from torch.utils.data.
- Lines 7–24: We define a dataset class, CatsDogsDataset, to load the images and labels. Inside this class:
  - Lines 8–11: We define the constructor, __init__(), which receives the root directory and the annotation file as parameters. It also receives an optional parameter, transform, which is applied to each image. The constructor initializes the instance variables.
  - Lines 13–14: We define the __len__() function, which returns the number of elements in the dataset.
  - Lines 16–24: We define the __getitem__() function, which receives one index parameter. We retrieve img_id and y_label from the annotation file, open the image, and store it in img. We apply the transformation (if any) and return the image and its label.
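The protocol that a map-style dataset follows (a length, index-based access, and an optional transform) can be tried without PyTorch or any image files. The sketch below is a hypothetical stand-in and not the lesson's class: InMemoryDataset, the random-free toy images, and the scale_to_unit transform are all assumptions introduced for illustration.

```python
import numpy as np


class InMemoryDataset:
    """Hypothetical stand-in that follows the same __len__/__getitem__ protocol."""

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.labels)

    def __getitem__(self, index):
        img = self.images[index]
        if self.transform is not None:
            img = self.transform(img)  # applied per sample, as in the lesson's class
        return img, float(self.labels[index])


def scale_to_unit(img):
    # Toy transform (an assumption): convert uint8 pixels to floats in [0, 1].
    return img.astype(np.float32) / 255.0


# Four all-white 8x8 RGB "images" with alternating labels (0 = cat, 1 = dog).
toy_images = np.full((4, 8, 8, 3), 255, dtype=np.uint8)
toy_labels = [0, 1, 0, 1]

toy_ds = InMemoryDataset(toy_images, toy_labels, transform=scale_to_unit)
sample_img, sample_label = toy_ds[1]
```

Because a data loader only relies on len() and indexing, checking a dataset class this way before wiring it to a loader tends to surface path and label problems early.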
Next, we create a pandas DataFrame that maps each image file name to its category label.
    import os

    train_df = pd.DataFrame(columns=["img_path","label"])
    train_df["img_path"] = os.listdir("train/")

    for idx, i in enumerate(os.listdir("train/")):
        if "cat" in i:
            train_df.loc[idx, "label"] = 0  # .loc avoids chained assignment
        if "dog" in i:
            train_df.loc[idx, "label"] = 1

    print(train_df)
    train_df.to_csv(r'train_csv.csv', index = False, header=True)
In the code above:
- Line 1: We import the os library.
- Line 3: We create a pandas DataFrame, train_df, with two columns: img_path to record the file path and label for storing the labels.
- Line 4: We call the listdir() method of the os module to populate the img_path column of train_df with the names of all files in the train directory.
- Lines 6–10: We use a for loop to iterate through each file name in the train directory. We assign label 0 if the file name contains cat and label 1 if the file name contains dog.
- Lines 12–13: Lastly, we print the DataFrame train_df and call the to_csv() method of train_df to save it as a CSV file.
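The same labeling rule can also be written as a function applied to the whole column, which avoids the explicit loop. This is a minimal sketch on a hard-coded list of file names. The names cat.0.jpg, dog.0.jpg, and cat.1.jpg and the helper label_from_name are assumptions made for illustration, not the lesson's code.

```python
import pandas as pd


def label_from_name(file_name):
    # Same rule as the loop: 0 for cats, 1 for dogs.
    if "cat" in file_name:
        return 0
    if "dog" in file_name:
        return 1
    raise ValueError(f"Cannot infer a label from {file_name!r}")


# Hypothetical file names standing in for os.listdir("train/").
file_names = ["cat.0.jpg", "dog.0.jpg", "cat.1.jpg"]

labels_df = pd.DataFrame({"img_path": file_names})
labels_df["label"] = labels_df["img_path"].map(label_from_name)
```

Raising an error for unexpected names is a design choice of this sketch. The loop in the lesson would leave such rows without a label instead.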
Next, we define a collate function that stacks the samples in a batch and returns the images and labels as NumPy arrays.
    import numpy as np

    def custom_collate_fn(batch):
        transposed_data = list(zip(*batch))
        labels = np.array(transposed_data[1])
        imgs = np.stack(transposed_data[0])
        return imgs, labels
In the code above:
Line 1: We import the NumPy library as ...