ResNet Model Definition

Learn how to define a ResNet model.

Training large neural networks can take days or weeks. Once these networks are trained, we can reuse their weights for new tasks, a technique known as transfer learning. As a result, we can fine-tune a network and get good results in a short period. Let’s look at how we can fine-tune a pretrained ResNet network in JAX and Flax.

Transfer learning

Prior to using transfer learning and fine-tuning the ResNet model, it's important to process the data, which was covered previously.

Pretrained ResNet models are trained to predict many classes, but the dataset we have contains only two. Therefore, we use the ResNet as the backbone and define a custom classification head.

Create a Head network

Create a Head network whose output matches the problem at hand, in this case, binary image classification.

from flax import linen as nn
from functools import partial

class Head(nn.Module):
    '''head model'''
    batch_norm_cls: partial = partial(nn.BatchNorm, momentum=0.9)
    @nn.compact
    def __call__(self, inputs, train: bool):
        output_n = inputs.shape[-1]
        x = self.batch_norm_cls(use_running_average=not train)(inputs)
        x = nn.Dropout(rate=0.25)(x, deterministic=not train)
        x = nn.Dense(features=output_n)(x)
        x = nn.relu(x)
        x = self.batch_norm_cls(use_running_average=not train)(x)
        x = nn.Dropout(rate=0.5)(x, deterministic=not train)
        x = nn.Dense(features=config["NUM_LABELS"])(x)
        return x

In the code above, we import the linen module from the flax library as nn to define the neural network architecture and the partial function from the functools library to create partial function applications. We define the Head class, which inherits from nn.Module and represents the head model. Inside this class:

  • Line 6: We call the partial() function on the nn.BatchNorm class with a fixed momentum value to define batch_norm_cls. We can use this partial function to create instances of the nn.BatchNorm layer that all share that momentum value.

  • Lines 7–17: We define a __call__() function and apply the @nn.compact decorator to it. The __call__() function defines the model layers and the forward pass over the input. Inside this function:

    • Line 9: We take the number of output features from the last dimension of the input and store it in output_n.

    • Lines 10–13: We call batch_norm_cls() to apply a batch normalization layer to the inputs, followed by a Dropout layer with a rate of 25%, a Dense layer with output_n features, and a ReLU activation.

    • Lines 14–17: Similarly, we apply another batch normalization layer, a Dropout layer with a rate of 50%, and a final Dense layer with NUM_LABELS output features. Lastly, we return the output.
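
To sanity-check the head in isolation, we can initialize it with dummy pooled features and run a forward pass. The following is a minimal sketch, not part of the course code: the config dictionary and the dummy input shape (2048 channels, matching ResNet-50's final feature width) are assumptions made for this example.

import jax
import jax.numpy as jnp

# Assumed for this sketch; in the course, config is defined in an earlier lesson.
config = {"NUM_LABELS": 2}

# Dummy pooled features: a batch of 4 vectors with 2048 channels each.
dummy = jnp.ones((4, 2048))

head = Head()
# init() creates the 'params' and 'batch_stats' collections; train=False keeps
# dropout deterministic and batch normalization in inference mode.
variables = head.init(jax.random.PRNGKey(0), dummy, train=False)
logits = head.apply(variables, dummy, train=False)
print(logits.shape)  # (4, 2)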

Combine ResNet backbone with head

Combine the pretrained ResNet backbone with the custom head we created above.

from jax_resnet import pretrained_resnet, slice_variables, Sequential
import jax.numpy as jnp

class Model(nn.Module):
    backbone: Sequential
    head: Head

    def __call__(self, inputs, train: bool):
        x = self.backbone(inputs)
        # average pool layer
        x = jnp.mean(x, axis=(1, 2))
        x = self.head(x, train)
        return x

In the code above:

  • Lines 1–2: We import the required library modules: pretrained_resnet, slice_variables, and Sequential from the jax_resnet library and the JAX version of NumPy as jnp.

  • Lines 4–13: We define the Model class, which inherits from nn.Module. We declare a backbone attribute of type Sequential and a head attribute of type Head. We define the __call__() function to apply the model to the given input. Inside this function:

    • Line 9: We apply the backbone model to the given input and store the output in the variable x.

    • Line 11: We call the jnp.mean() method over the spatial axes (1, 2) to average-pool the backbone’s feature map, leaving one feature vector per image, and update x.

    • Lines 12–13: We apply the head model to the value x and return the output.
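
The jnp.mean() call in line 11 performs global average pooling. As a quick illustration (the feature-map shape below is an assumed example for a ResNet-50 backbone, not taken from the course code):

import jax.numpy as jnp

# Assumed backbone output: (batch, height, width, channels).
feature_map = jnp.ones((4, 7, 7, 2048))

# Global average pooling: average over the spatial axes (1, 2),
# leaving one 2048-dimensional feature vector per image.
pooled = jnp.mean(feature_map, axis=(1, 2))
print(pooled.shape)  # (4, 2048)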

Load pretrained ResNet-50

Next, we create a function that loads the pretrained ResNet model. We omit the last two layers of the network because we have defined a custom head. The function returns the ResNet backbone and its parameters. The backbone’s parameters are obtained using the slice_variables function.

def get_backbone_and_params(model_arch: str):
    if model_arch == 'resnet50':
        resnet_tmpl, params = pretrained_resnet(50)
        model = resnet_tmpl()
    else:
        raise NotImplementedError
    # get model & param structure for backbone
    start, end = 0, len(model.layers) - 2
    backbone = Sequential(model.layers[start:end])
    backbone_params = slice_variables(params, start, end)
    return backbone, backbone_params

We define the get_backbone_and_params() function to retrieve the backbone model and related parameters. This function receives a string argument, model_arch, that specifies the model architecture to use. Inside this function:

    ...