Image Classification with JAX and Flax
Learn how to create models that classify images using JAX and Flax.
Flax is a neural network library for JAX. JAX is a Python library that provides high-performance numerical computing for machine learning research. JAX provides an API similar to NumPy, making it easy to adopt. JAX also includes other functionality that is useful in machine learning research, illustrated in the short example after this list:

- Automatic differentiation: JAX supports forward- and reverse-mode automatic differentiation of numerical functions with functions such as `grad`, `jacrev`, and `hessian`.
- Vectorization: JAX supports automatic vectorization via the `vmap` function. It also makes it easy to parallelize large-scale data processing via the `pmap` function.
- JIT compilation: JAX uses XLA for just-in-time (JIT) compilation and execution of code on GPUs and TPUs.
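The sketch below is a minimal, self-contained illustration of these three features (the function `f` is a toy example, not part of the lesson's model):

```python
import jax
import jax.numpy as jnp

# A toy scalar function: f(x) = x^2 + 3x
def f(x):
    return x ** 2 + 3.0 * x

# Automatic differentiation: grad(f) computes df/dx = 2x + 3
df = jax.grad(f)
print(df(2.0))                      # 7.0

# Vectorization: vmap applies f element-wise over a batch without an explicit loop
batched_f = jax.vmap(f)
print(batched_f(jnp.arange(4.0)))   # [ 0.  4. 10. 18.]

# JIT compilation: jit compiles f with XLA so repeated calls run faster
fast_f = jax.jit(f)
print(fast_f(2.0))                  # 10.0
```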
In this lesson, we will use JAX and Flax to build a simple convolutional neural network.
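As a rough preview, a convolutional model in Flax is written as a `flax.linen` module. The architecture below is only an illustrative assumption, not necessarily the exact network we end up building:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

class SimpleCNN(nn.Module):
    """A tiny convolutional network for two-class (cat vs. dog) image classification."""

    @nn.compact
    def __call__(self, x):
        # x has shape (batch, height, width, channels)
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=2)(x)      # one logit per class
        return x

# Initialize parameters with a dummy batch to check that shapes line up.
model = SimpleCNN()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 64, 64, 3)))
```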
Loading the dataset
We’ll use the Cats and Dogs dataset we used in previous lessons. Let’s unzip the zip file containing the dataset.
```python
import zipfile

with zipfile.ZipFile('../train.zip', 'r') as zip_ref:
    zip_ref.extractall('.')
```
In the code above:
- Line 1: We import the `zipfile` library.
- Lines 3–4: We use the `ZipFile()` class of the `zipfile` module to open the zip file in read mode as `zip_ref`. We use the `with` statement to automatically close the file after the code executes. We call the `extractall()` method to extract the contents of the zip file into the current directory.
Flax doesn’t ship with any data loading tools. We can use the data loaders from PyTorch or TensorFlow. In this case, let’s load the data using PyTorch. The first step is to define the dataset class.
```python
import os
import torch
from PIL import Image
import pandas as pd
from torch.utils.data import Dataset

class CatsDogsDataset(Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        y_label = torch.tensor(float(self.annotations.iloc[index, 1]))

        if self.transform is not None:
            img = self.transform(img)

        return (img, y_label)
```
In the code above:
- Lines 1–5: We import the required libraries: `os`, `torch`, the `Image` module from `PIL`, `pandas`, and the `Dataset` class from `torch.utils.data`.
- Lines 7–24: We define a dataset class, `CatsDogsDataset`, to load the images and labels. Inside this class:
  - Lines 8–11: We define the constructor, `__init__()`, which receives the root directory and the annotation file as parameters. It also receives an optional parameter, `transform`, used to apply transformations to the images. The constructor initializes the class instance variables.
  - Lines 13–14: We define the `__len__()` function, which returns the number of elements in the dataset.
  - Lines 16–24: We define the `__getitem__()` function, which receives one `index` parameter. We retrieve the `img_id` and `y_label` from the annotation file and store the image in `img`. We apply the transformation (if any) and return the image and its label.
Next, we create a pandas DataFrame that maps each image file name to its category label.
```python
import os

train_df = pd.DataFrame(columns=["img_path", "label"])
train_df["img_path"] = os.listdir("train/")

for idx, i in enumerate(os.listdir("train/")):
    if "cat" in i:
        train_df.loc[idx, "label"] = 0
    if "dog" in i:
        train_df.loc[idx, "label"] = 1

print(train_df)
train_df.to_csv(r'train_csv.csv', index=False, header=True)
```
In the code above:
- Line 1: We import the `os` library.
- Line 3: We create a pandas DataFrame, `train_df`, with two columns: `img_path` to record the file path and `label` to store the labels.
- Line 4: We call the `listdir()` method of the `os` module to populate the `img_path` column of `train_df` with all file names present in the `train` directory.
- Lines 6–10: We use a `for` loop to iterate through each file name in the `train` directory. We assign label 0 if the file name contains `cat` and label 1 if it contains `dog`.
- Lines 12–13: Lastly, we print the DataFrame `train_df` and call its `to_csv()` method to save it as a CSV file. We can now pair this CSV with the dataset class, as sketched below.
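With the annotation CSV in place, we can try out the `CatsDogsDataset` class. The transform below is an assumption for illustration only: it resizes each image and converts it to a NumPy array so that batches can later be stacked for JAX (the 64×64 size, the normalization, and the use of `torchvision` transforms are all arbitrary choices here):

```python
import numpy as np
from torchvision import transforms

# Hypothetical transform: resize each PIL image and convert it to a
# float32 NumPy array in [0, 1].
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    lambda img: np.asarray(img, dtype=np.float32) / 255.0,
])

dataset = CatsDogsDataset(
    root_dir="train/",
    annotation_file="train_csv.csv",
    transform=transform,
)

img, label = dataset[0]
print(img.shape, label)  # e.g. (64, 64, 3) and a scalar label tensor
```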
Next, we define a collate function that stacks the samples in a batch and returns the images and labels as NumPy arrays.
```python
import numpy as np

def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels
```
In the code above:
- Line 1: We import the NumPy library as `np`.