ResNet Model Definition
Learn how to define a ResNet model.
Training large neural networks can take days or weeks. Once a network is trained, we can reuse its weights on new tasks, a technique known as transfer learning. This lets us fine-tune a new network and get good results in a short time. Let's look at how to fine-tune a pretrained ResNet network in JAX and Flax.
Before applying transfer learning and fine-tuning the ResNet model, it's important to process the data, which we covered previously.
Pretrained ResNet models are trained to predict many classes, while our dataset contains only two. Therefore, we use the ResNet as a backbone and define a custom classification layer on top of it.
Create a Head network

Create a `Head` network whose output matches the problem at hand, in this case, binary image classification.
```python
from flax import linen as nn
from functools import partial

class Head(nn.Module):
    '''head model'''
    batch_norm_cls: partial = partial(nn.BatchNorm, momentum=0.9)
    @nn.compact
    def __call__(self, inputs, train: bool):
        output_n = inputs.shape[-1]
        x = self.batch_norm_cls(use_running_average=not train)(inputs)
        x = nn.Dropout(rate=0.25)(x, deterministic=not train)
        x = nn.Dense(features=output_n)(x)
        x = nn.relu(x)
        x = self.batch_norm_cls(use_running_average=not train)(x)
        x = nn.Dropout(rate=0.5)(x, deterministic=not train)
        x = nn.Dense(features=config["NUM_LABELS"])(x)
        return x
```
In the code above, we import the `linen` module from the `flax` library as `nn` to define the neural network architecture, and the `partial` function from the `functools` library to create partial function applications. We define the `Head` class, inheriting from `nn.Module`, to represent the head model. Inside this class:
- Line 6: We call the `partial()` function on the `nn.BatchNorm` class with a fixed `momentum` value to define the partial function `batch_norm_cls`. We can use this partial function to create instances of the `nn.BatchNorm` layer with a specific `momentum` value.
- Lines 7–17: We define the `__call__()` function and apply the `@nn.compact` decorator to it. The `__call__()` function defines the model layers and the forward pass of the input. Inside this function:
  - Line 9: We take the number of output features from the last dimension of `inputs` and store it in `output_n`.
  - Lines 10–13: We call the `batch_norm_cls()` function to apply a batch normalization layer to `inputs`. We apply a `Dropout` layer with a `rate` of 25% to the output of the previous layer. We then apply a `Dense` layer with `output_n` features and a ReLU activation, respectively.
  - Lines 14–17: Similarly, we apply a batch normalization layer, a `Dropout` layer with a `rate` of 50%, and a `Dense` layer with `NUM_LABELS` features. Lastly, we return the output.
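To sanity-check the head in isolation, we can initialize it and apply it to dummy pooled features. The following is a minimal sketch, assuming `config = {"NUM_LABELS": 2}` was defined during the earlier data-processing step and that the backbone yields 2,048-dimensional features (as ResNet-50 does):

```python
import jax
import jax.numpy as jnp

config = {"NUM_LABELS": 2}  # assumption: set during the data-processing step

head = Head()
dummy_features = jnp.ones((1, 2048))  # ResNet-50 yields 2,048 features per image
variables = head.init(jax.random.PRNGKey(0), dummy_features, train=False)

# With train=False, dropout is deterministic and batch norm uses running stats,
# so no extra RNGs or mutable collections are needed.
logits = head.apply(variables, dummy_features, train=False)
print(logits.shape)  # (1, 2)
```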
Combine ResNet backbone with head
Combine the pretrained ResNet backbone with the custom head we created above.
```python
from jax_resnet import pretrained_resnet, slice_variables, Sequential
import jax.numpy as jnp

class Model(nn.Module):
    backbone: Sequential
    head: Head

    def __call__(self, inputs, train: bool):
        x = self.backbone(inputs)
        # average pool layer
        x = jnp.mean(x, axis=(1, 2))
        x = self.head(x, train)
        return x
```
In the code above:

- Lines 1–2: We import the required library modules: `pretrained_resnet`, `slice_variables`, and `Sequential` from the `jax_resnet` library, and the JAX version of NumPy as `jnp`.
- Lines 4–13: We define the `Model` class, inheriting from `nn.Module`. We define the `backbone` attribute of the `Sequential` type and the `head` attribute of the `Head` type. We define the `__call__()` function to apply the model to the given input. Inside this function:
  - Line 9: We apply the `backbone` model to the given input and store the output in the variable `x`.
  - Line 11: We call the `jnp.mean()` method to average `x` over its two spatial axes, performing global average pooling (illustrated in the sketch after this list), and store the result back in `x`.
  - Lines 12–13: We apply the `head` model to `x` and return the output.
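To make the pooling step concrete, here is a standalone sketch of what Line 11 does. The 7×7×2048 feature map is what ResNet-50 typically produces for 224×224 inputs (the input resolution is an assumption here):

```python
import jax.numpy as jnp

# A (batch, height, width, channels) feature map, as the backbone outputs.
feature_map = jnp.ones((1, 7, 7, 2048))

# Averaging over the two spatial axes acts as a global average pool,
# collapsing each channel to a single value per image.
pooled = jnp.mean(feature_map, axis=(1, 2))
print(pooled.shape)  # (1, 2048)
```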
Load pretrained ResNet-50
Next, we create a function that loads the pretrained ResNet model. We omit the last two layers of the network because we have defined a custom head. The function returns the ResNet model and its parameters. The model parameters are obtained using the `slice_variables` function.
```python
def get_backbone_and_params(model_arch: str):
    if model_arch == 'resnet50':
        resnet_tmpl, params = pretrained_resnet(50)
        model = resnet_tmpl()
    else:
        raise NotImplementedError

    # get model & param structure for backbone
    start, end = 0, len(model.layers) - 2
    backbone = Sequential(model.layers[start:end])
    backbone_params = slice_variables(params, start, end)
    return backbone, backbone_params
```
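As a quick usage sketch, assuming the pretrained ResNet-50 weights download successfully and that we feed 224×224 RGB images in NHWC layout:

```python
backbone, backbone_params = get_backbone_and_params('resnet50')

dummy_images = jnp.ones((1, 224, 224, 3))  # assumed input shape
features = backbone.apply(backbone_params, dummy_images, mutable=False)
print(features.shape)  # a spatial feature map, e.g. (1, 7, 7, 2048)
```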
We define the `get_backbone_and_params()` function to retrieve the `backbone` model and related parameters. This function receives a string argument, `model_arch`, that specifies the model architecture to use. Inside this function: