Training the ResNet Model
Learn how to train a distributed ResNet model with Flax.
Apart from designing custom CNN architectures, we can use prebuilt architectures. ResNet is one such popular architecture, and in most cases, we can achieve better performance by using it. In this chapter, we will learn how to perform distributed training of a ResNet model in Flax.
In this code example, we will train the model from scratch, meaning that we will not use the pretrained weights. In a separate chapter, we’ve covered how to perform transfer learning with ResNet.
Prior to training the ResNet model, it’s important to process the data, which was covered previously.
Instantiate the Flax ResNet model
With the data in place, we instantiate the Flax ResNet model using the `flaxmodels` package. The instantiation requires:

- The type of output
- The `pretrained` argument to load pretrained weights, in this case `None`
- The desired number of classes
- The data type
```python
import jax.numpy as jnp
import flaxmodels as fm

num_classes = 2
dtype = jnp.float32
model = fm.ResNet50(output='log_softmax', pretrained=None, num_classes=num_classes, dtype=dtype)
```
In the code above:

- Lines 1–2: We import the required libraries: the JAX version of NumPy as `jnp` and `flaxmodels` as `fm`.
- Lines 4–5: We define two variables: `num_classes` to store the number of classes and `dtype` to store the data type for numerical computations.
- Line 6: We call the `fm.ResNet50()` method to create an instance of the ResNet-50 architecture. We set the `output` parameter to `'log_softmax'` to apply the LogSoftmax activation to the output layer. We set the `pretrained` parameter to `None`, which means we are using the untrained model. We also pass the number of classes and the data type.
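As a quick sanity check, we can initialize the model on a dummy batch and inspect the variable collections it creates. This is a minimal sketch: the input resolution and the PRNG seed are assumptions for illustration.

```python
import jax

key = jax.random.PRNGKey(0)                      # assumed seed for illustration
dummy = jnp.ones((1, 224, 224, 3), dtype=dtype)  # assumed NHWC input resolution
variables = model.init(key, dummy)

# The ResNet's BatchNorm layers create a 'batch_stats' collection
# alongside the usual 'params' collection.
print(variables.keys())  # expected: 'params' and 'batch_stats'
```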
There are two variants of the ResNet block: the original ResNet block and the pre-activation ResNet block. Both have their pros and cons. A sketch of both blocks is shown below:
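The following is a minimal, simplified sketch of the two variants (the class names and the omission of projection shortcuts are illustrative, not the `flaxmodels` implementation). The original block applies the activation after the residual addition, while the pre-activation block moves BatchNorm and ReLU before each convolution and leaves the identity path untouched.

```python
import flax.linen as nn

class ResNetBlock(nn.Module):
    # Original block: Conv -> BatchNorm -> ReLU, with ReLU after the addition.
    filters: int

    @nn.compact
    def __call__(self, x, train: bool = True):
        residual = x
        y = nn.Conv(self.filters, (3, 3), use_bias=False)(x)
        y = nn.BatchNorm(use_running_average=not train)(y)
        y = nn.relu(y)
        y = nn.Conv(self.filters, (3, 3), use_bias=False)(y)
        y = nn.BatchNorm(use_running_average=not train)(y)
        return nn.relu(y + residual)

class PreActResNetBlock(nn.Module):
    # Pre-activation block: BatchNorm -> ReLU -> Conv, identity path untouched.
    filters: int

    @nn.compact
    def __call__(self, x, train: bool = True):
        residual = x
        y = nn.BatchNorm(use_running_average=not train)(x)
        y = nn.relu(y)
        y = nn.Conv(self.filters, (3, 3), use_bias=False)(y)
        y = nn.BatchNorm(use_running_average=not train)(y)
        y = nn.relu(y)
        y = nn.Conv(self.filters, (3, 3), use_bias=False)(y)
        return y + residual
```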
Compute metrics
We define the metrics for evaluating the model during training. Let’s start by creating the loss function.
```python
import jax
import optax

def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()
```
In the code above:

- Lines 1–2: We import the `jax` library to get the one-hot encoding of the labels and the `optax` library, which provides the loss function.
- Lines 4–6: We define the `cross_entropy_loss()` function that receives two parameters: `logits` is the output of the model and `labels` are the actual labels. Inside this function:
  - Line 5: We call the `one_hot()` method of the `jax.nn` module to get the one-hot encoding of the `labels`.
  - Line 6: Lastly, we call the `softmax_cross_entropy()` method of the `optax` library to compute the loss. We calculate the mean of the loss and return the result.
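To make the computation concrete, here is a tiny worked example with made-up values:

```python
logits = jnp.array([[2.0, -1.0]])  # hypothetical model output for a single example
labels = jnp.array([0])            # the true class index

loss = cross_entropy_loss(logits=logits, labels=labels)
# softmax([2.0, -1.0]) ≈ [0.953, 0.047], so the loss is -log(0.953) ≈ 0.049
```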
Next, we define a function that computes and returns the loss and accuracy.
```python
def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics
```
We define the `compute_metrics()` function to compute the loss and accuracy. Inside this function:

- Line 2: We call the `cross_entropy_loss()` function to calculate the `loss`.
- Line 3: We calculate the accuracy by comparing the actual `labels` with the predicted labels. We calculate the mean and store the result in `accuracy`.
- Lines 4–7: We define the `metrics` dictionary with two metrics: `loss` and `accuracy`.
- Line 8: Lastly, we return the computed `metrics`.
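For instance, calling the function on a small batch of made-up predictions looks like this:

```python
logits = jnp.array([[0.2, 0.8],   # predicted class 1
                    [0.9, 0.1]])  # predicted class 0
labels = jnp.array([1, 0])        # both predictions are correct

metrics = compute_metrics(logits=logits, labels=labels)
print(metrics['accuracy'])  # 1.0, since argmax matches the labels for both examples
```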
Create the Flax model training state
Flax provides a training state for storing training information, and it can be extended with new fields. In this case, we need to extend it with the batch statistics, since the ResNet model's BatchNorm layers compute `batch_stats`.
```python
import flax
from flax.training import train_state

class TrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict
```
In the code above:

- Lines 1–2: We import the `flax` library and the `train_state` module from `flax.training`.
- Lines 4–5: We define the `TrainState` class, which extends `train_state.TrainState` with a `batch_stats` field to hold the model's batch statistics.
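With the class defined, a custom training state can be created from the initialized variables. The following is a minimal sketch: the optimizer choice and learning rate are assumptions, and `variables` is the dictionary returned by `model.init()` shown earlier.

```python
import optax

state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],            # from model.init() above
    batch_stats=variables['batch_stats'],  # the extra field we added
    tx=optax.adam(learning_rate=1e-4),     # assumed optimizer and learning rate
)
```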