Applying BatchNorm and DropOut Layers in JAX and Flax

Learn how to apply BatchNorm and DropOut layers in JAX and Flax.

Jitting functions in Flax makes them faster but requires that the functions have no side effects. The fact that jitted functions can’t have side effects introduces a challenge when dealing with stateful items, such as model parameters, and stateful layers, such as batch normalization layers.
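To make this concrete, here is a minimal sketch (not from the lesson code) of how state is threaded through a jitted function: rather than mutating a running statistic in place, the function takes the current value as an argument and returns the updated value. The function and variable names here are hypothetical.

import jax
import jax.numpy as jnp

# State can't be mutated inside a jitted function, so we pass it in
# and return the updated value explicitly.
@jax.jit
def update_running_mean(running_mean, batch, momentum=0.9):
    batch_mean = jnp.mean(batch)
    return momentum * running_mean + (1.0 - momentum) * batch_mean

running_mean = jnp.array(0.0)
batch = jnp.arange(10.0)
running_mean = update_running_mean(running_mean, batch)  # new state returned, not mutated

Flax follows the same pattern: model parameters and batch statistics are passed into apply() and the updated collections are returned, rather than being stored inside the model object.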

In this lesson, we’ll create a network with the BatchNorm and DropOut layers. After that, we’ll see how to supply the random number generator key for the DropOut layer and how to handle the batch statistics when training the network.

Before we can train the model and apply these layers, the data must be processed, which was covered previously.

Define Flax model with BatchNorm and DropOut

Let’s define the Flax network with the BatchNorm and DropOut layers. In the network, we introduce the training variable to control when the batch stats should be updated. We ensure that they aren’t updated during testing.

In the BatchNorm layer, we set use_running_average to False during training, meaning that the statistics stored in batch_stats will not be used; instead, the batch statistics are computed from the input. The DropOut layer takes the following arguments:

  • rate: the dropout probability.
  • deterministic: if False, the inputs are scaled and masked (dropout is applied). If True, no mask is applied and the inputs are returned as they are.
from flax import linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x, training):
        x = nn.Conv(features=128, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.Dense(features=128)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.Dropout(0.2, deterministic=not training)(x)
        x = nn.relu(x)
        x = nn.Dense(features=2)(x)
        x = nn.log_softmax(x)
        return x

In the code above:

  • Line 1: We import the linen module from the flax library as nn.

  • Lines 3–23: We define the CNN class that inherits from the Flax module nn.Module. In this class, we define the forward pass __call__() function. Inside the __call__() function:

    • Lines 6–14: We start the forward pass with a convolution layer with 128 filters and a 3-by-3 kernel. We apply ReLU activation followed by max pooling with a 2-by-2 window shape and stride. The same pattern is repeated with 64 and then 32 convolution filters.

    • Lines 15–17: The output of the above layers is flattened and passes through two Dense layers with 256 and 128 features, respectively.

    • Lines 18–19: We apply the batch normalization layer and the dropout layer with a 20% dropout rate.

    • Lines 20–23: We apply ReLU activation again and pass the output through the third Dense layer with 2 features for binary classification. Lastly, we apply the LogSoftmax activation function and return the output.
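As a quick sanity check (a minimal sketch, not part of the lesson code), we can initialize this model to see the two variable collections it creates: params for the trainable weights and batch_stats for the BatchNorm running statistics. The (1, 256, 256, 3) input shape below is an assumption for illustration; use the shape of your own processed images.

import jax
import jax.numpy as jnp

model = CNN()
dummy_images = jnp.ones((1, 256, 256, 3))  # assumed input shape, for illustration only

# init() runs the model once to create its variables. With training=False,
# DropOut is deterministic, so only a 'params' PRNG key is required.
variables = model.init(jax.random.PRNGKey(42), dummy_images, training=False)
params = variables['params']            # trainable weights
batch_stats = variables['batch_stats']  # BatchNorm running mean/variance

# Inference-mode apply: the running averages are used and dropout is a no-op,
# so neither rngs nor mutable collections are needed.
logits = model.apply({'params': params, 'batch_stats': batch_stats},
                     dummy_images, training=False)
print(logits.shape)  # (1, 2)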

Create loss function

The next step is to create the loss function. When applying the model, we:

  • Pass the batch_stats collection along with the params.
  • Set training to True.
  • Mark batch_stats as mutable.
  • Provide a random number generator key for the DropOut layer.
import optax

def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=2)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

def compute_loss(params, batch_stats, images, labels):
    logits, batch_stats = CNN().apply({'params': params, 'batch_stats': batch_stats}, images, training=True, rngs={'dropout': jax.random.PRNGKey(0)}, mutable=['batch_stats'])
    loss = cross_entropy_loss(logits=logits, labels=labels)
    return loss, (logits, batch_stats)

In the code above:

  • Line 1: We import the optax library to calculate the loss.

  • Lines 3–5: We define the cross_entropy_loss() function that receives two parameters: logits are the output of the model and labels are the actual labels. Inside this function, we call the one_hot() method of the jax.nn module to get the one-hot ...