Training the ResNet Model

Learn how to train a distributed ResNet model with Flax.

Apart from designing custom CNN architectures, we can use architectures that have already been built and widely tested. ResNet is one such popular architecture, and an established design like this often performs better than a network we design ourselves. In this chapter, we will learn how to perform distributed training of a ResNet model in Flax.

In this code example, we will train the model from scratch, meaning that we will not use pretrained weights. A separate chapter covers how to perform transfer learning with ResNet.

Prior to training the ResNet model, it’s important to process the data, which was covered previously.

Instantiate the Flax ResNet model

With the data in place, we instantiate the Flax ResNet model using the flaxmodels package. The instantiation requires:

  • The type of output
  • The pretrained argument, which selects the pretrained weights to load; in this case it is None, so no pretrained weights are loaded
  • The desired number of classes
  • The data type

import jax.numpy as jnp
import flaxmodels as fm

num_classes = 2
dtype = jnp.float32
model = fm.ResNet50(output='log_softmax', pretrained=None, num_classes=num_classes, dtype=dtype)

In the code above:

  • Lines 1–2: We import the required libraries: the JAX version of NumPy as jnp and flaxmodels as fm.

  • Lines 4–5: We define two variables: num_classes to store the number of classes and dtype to store the data type for numerical computations.

  • Line 6: We call the fm.ResNet50() constructor to create an instance of the ResNet-50 architecture. We set the output parameter to log_softmax so that the model applies the LogSoftmax activation to the output layer and returns log-probabilities. We set the pretrained parameter to None, which means we are using the untrained model. We also pass the number of classes and the data type.
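To make the output='log_softmax' setting more concrete, here is a minimal sketch in plain JAX (it does not use flaxmodels). The raw_scores values are made up for illustration and stand in for what a two-class model might produce before the final activation.

```python
import jax
import jax.numpy as jnp

# Hypothetical raw class scores for one image and two classes.
raw_scores = jnp.array([[1.5, -0.3]])

# LogSoftmax turns raw scores into log-probabilities.
log_probs = jax.nn.log_softmax(raw_scores, axis=-1)

# Exponentiating recovers probabilities that sum to 1 per example.
probs = jnp.exp(log_probs)

# The predicted class is the index of the largest log-probability.
predicted = jnp.argmax(log_probs, axis=-1)

print(log_probs, probs, predicted)
```

Because the logarithm is monotonic, taking the argmax of the log-probabilities picks the same class as taking the argmax of the raw scores.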

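There are two variants of the ResNet block: the original ResNet block and the pre-activation ResNet block. In the original block, batch normalization and ReLU follow each convolution, and the last ReLU is applied after the skip connection is added. In the pre-activation block, batch normalization and ReLU come before each convolution, which leaves the skip path as a pure identity. Both have their pros and cons. A visual representation of both blocks is shown below: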

[Figure: Original and preactivated ResNet block]

Compute metrics

We define the metrics for evaluating the model during training. Let’s start by creating the loss function.

import jax
import optax  # provides softmax_cross_entropy
def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

In the code above:

  • Line 1: We import the jax library to get the one-hot encoding of the labels.

  • Lines 3–5: We define the cross_entropy_loss() function, which takes two keyword-only parameters: logits, the output of the model, and labels, the actual labels. Inside this function:

    • Line 4: We call the one_hot() method of the jax.nn module to get the one-hot encoding of the labels.

    • Line 5: Lastly, we call the softmax_cross_entropy() method of the optax library to compute the loss. We calculate the mean of the loss and return the results.
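As a rough illustration of what this loss computes, the sketch below rebuilds softmax cross-entropy by hand in plain JAX. The function name manual_cross_entropy and the toy values are assumptions made for this example. It also checks one detail relevant to this chapter: since the model already returns log-softmax outputs, applying the softmax-based loss to them should give the same value as applying it to raw scores.

```python
import jax
import jax.numpy as jnp


def manual_cross_entropy(logits, labels, num_classes):
    # One-hot encode the integer labels.
    onehot = jax.nn.one_hot(labels, num_classes=num_classes)
    # Log-probabilities of each class.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Negative log-probability of the true class, averaged over the batch.
    return -jnp.sum(onehot * log_probs, axis=-1).mean()


# Toy batch: two examples, two classes.
raw_scores = jnp.array([[2.0, 0.5], [0.1, 1.2]])
labels = jnp.array([0, 1])

loss_from_scores = manual_cross_entropy(raw_scores, labels, 2)

# Feeding log-softmax outputs instead of raw scores leaves the loss unchanged,
# because log-softmax applied to log-probabilities returns them as they are.
loss_from_log_probs = manual_cross_entropy(
    jax.nn.log_softmax(raw_scores, axis=-1), labels, 2
)

print(loss_from_scores, loss_from_log_probs)
```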

Next, we define a function that computes and returns the loss and accuracy.

def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics

We define the compute_metrics() function to compute the loss and accuracy. Inside this function:

  • Line 2: We call the cross_entropy_loss() function to calculate the loss.

  • Line 3: We calculate the accuracy. We take the argmax over the last axis of logits as the predicted labels, compare them with the actual labels, and store the mean of the matches in accuracy.

  • Lines 4–7: We define the metrics dictionary with two metrics: loss and accuracy.

  • Line 8: Lastly, we return the computed metrics.
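The accuracy computation can be traced step by step on a toy batch. The values below are invented for illustration; the intermediate names predictions and matches are not part of the original function.

```python
import jax.numpy as jnp

# Toy batch: three examples, two classes.
logits = jnp.array([[2.0, 0.5], [0.1, 1.2], [3.0, 0.2]])
labels = jnp.array([0, 1, 1])

# Predicted class = index of the largest score in each row.
predictions = jnp.argmax(logits, -1)

# Element-wise comparison gives a boolean array of correct predictions.
matches = predictions == labels

# The mean of the boolean array is the fraction of correct predictions.
accuracy = jnp.mean(matches)

print(predictions, matches, accuracy)
```

Here the third example is misclassified, so two of the three predictions match the labels.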

Create the Flax model training state

Flax provides a training state for storing training information. The training state can be modified to add new information. In this case, we need to alter the training state to add the batch statistics: the batch normalization layers in the ResNet model maintain running statistics, stored as batch_stats, which are kept separately from the trainable parameters.

import flax
from flax.training import train_state

class TrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict  # running BatchNorm statistics

In the code above:

  • Lines 1–2: We import the flax library and train_state ...