Applying BatchNorm and DropOut Layers in JAX and Flax
Learn how to apply BatchNorm and DropOut layers in JAX and Flax.
Jitting functions in Flax makes them faster, but requires that the functions have no side effects. Because jitted functions can’t have side effects, stateful items, such as model parameters, and stateful layers, such as batch normalization layers, become a challenge to handle.
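As a toy illustration of this constraint (not part of the lesson’s model), a jitted function can’t update state in place; the state has to be passed in as an argument and returned as an output:

import jax

@jax.jit
def step(count, x):
    # The "state" (count) comes in as an argument and goes back out as a return value.
    return count + 1, x * 2.0

count = 0
count, y = step(count, 3.0)

Flax follows the same pattern: parameters and batch statistics live outside the model function and are threaded through explicitly, as we’ll see below.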
In this lesson, we’ll create a network with the BatchNorm and DropOut layers. After that, we’ll see how to handle the random numbers needed by the DropOut layer and how to update the batch statistics when training the network.
Before training and applying these layers, it’s important to process the data, a step that was covered previously.
Define Flax model with BatchNorm and DropOut
Let’s define the Flax network with the BatchNorm and DropOut layers. In the network, we introduce the training variable to control when the batch stats should be updated; we ensure that they aren’t updated during testing. In the BatchNorm layer, we set use_running_average to not training, so during training the stats stored in batch_stats aren’t used and the batch stats of the input are computed instead. The DropOut layer takes the following arguments:
- rate: the dropout probability.
- deterministic: if it is deterministic, no mask is applied and the inputs are returned as they are. Otherwise, the inputs are scaled and masked.
from flax import linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x, training):
        x = nn.Conv(features=128, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.Dense(features=128)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.Dropout(0.2, deterministic=not training)(x)
        x = nn.relu(x)
        x = nn.Dense(features=2)(x)
        x = nn.log_softmax(x)
        return x
In the code above:
- Line 1: We import the linen module from the flax library as nn.
- Lines 3–23: We define the CNN class that inherits from the Flax module nn.Module. In this class, we define the forward pass __call__() function. Inside the __call__() function:
  - Lines 6–14: We start the forward pass with a convolution layer with 128 filters and a kernel size of 3 by 3. We apply ReLU activation followed by max pooling with a window shape and stride of 2 by 2. Similarly, the same layers are added again with 64 and 32 convolution filters, respectively.
  - Lines 15–17: The output of the above layers is flattened and passed through two Dense layers with 256 and 128 features, respectively.
  - Lines 18–19: We apply batch normalization and a dropout layer with a 20% dropout rate.
  - Lines 20–23: We apply ReLU activation again and pass the output through the third Dense layer with 2 features for binary classification. Lastly, we apply the LogSoftmax activation function and return the output.
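To see how these stateful layers show up in the model’s variables, here is a minimal sketch of initializing the network. The 64×64×3 input shape and the variable names (model, variables, params, batch_stats) are assumptions for illustration:

import jax
import jax.numpy as jnp

model = CNN()
# With training=False, BatchNorm uses running averages and Dropout is deterministic,
# so only a 'params' RNG is needed at initialization.
variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 64, 64, 3]), training=False)
params = variables['params']            # trainable weights
batch_stats = variables['batch_stats']  # BatchNorm running mean and variance

# In evaluation mode, the stored batch_stats are read and nothing is mutated.
logits = model.apply({'params': params, 'batch_stats': batch_stats},
                     jnp.ones([1, 64, 64, 3]), training=False)

During training, however, the batch_stats collection changes on every call, which is why the loss function below marks it as mutable.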
Create loss function
The next step is to create the loss function. When applying the model, we:
- Pass the batch_stats parameters.
- Set training to True.
- Mark the batch_stats collection as mutable.
- Provide the random number generator key for the DropOut layer.
import optax

def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=2)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

def compute_loss(params, batch_stats, images, labels):
    logits, batch_stats = CNN().apply({'params': params, 'batch_stats': batch_stats}, images, training=True, rngs={'dropout': jax.random.PRNGKey(0)}, mutable=['batch_stats'])
    loss = cross_entropy_loss(logits=logits, labels=labels)
    return loss, (logits, batch_stats)
In the code above:
- Line 1: We import the optax library to calculate the loss.
- Lines 3–5: We define the cross_entropy_loss() function that receives two parameters: logits are the output of the model, and labels are the actual labels. Inside this function, we call the one_hot() method of the jax.nn module to get the one-hot encoding of the labels.
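Putting it together, here is a minimal sketch of a jitted training step that uses compute_loss and threads the updated batch statistics through. The optimizer choice, learning rate, and names (tx, opt_state, train_step) are assumptions; any optax optimizer works the same way, and params and batch_stats come from the initialization sketch above:

import jax
import optax

tx = optax.adam(learning_rate=1e-3)
opt_state = tx.init(params)

@jax.jit
def train_step(params, batch_stats, opt_state, images, labels):
    # Differentiate compute_loss with respect to params; the aux output carries
    # the logits and the mutated 'batch_stats' collection.
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, (logits, new_state)), grads = grad_fn(params, batch_stats, images, labels)
    param_updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, param_updates)
    # new_state is the dict of mutated collections returned by apply(),
    # so the fresh batch statistics live under its 'batch_stats' key.
    return params, new_state['batch_stats'], opt_state, loss

Because the state (params, batch_stats, opt_state) is passed in and returned explicitly, the whole step stays side-effect free and can be jitted.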