Applying BatchNorm and Dropout Layers in JAX and Flax
Learn how to apply BatchNorm and Dropout layers in JAX and Flax.
Jitting functions with JAX's jax.jit makes them faster, but it requires the functions to be pure, that is, free of side effects. This restriction creates a challenge for stateful values, such as model parameters, and for stateful layers, such as batch normalization layers, which update their running statistics during training.
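To see why side effects clash with jitting, consider this minimal sketch. The function double and the list calls are hypothetical, for illustration only. The Python body runs only while jax.jit traces the function, so the list append happens once per trace rather than once per call.

```python
import jax
import jax.numpy as jnp

calls = []  # hypothetical Python-side state, used only to observe tracing


@jax.jit
def double(x):
    calls.append(1)  # side effect: executes only while JAX traces the function
    return x * 2


double(jnp.ones(3))  # first call: traced and compiled, so calls grows to 1
double(jnp.ones(3))  # same shape and dtype: cached executable, calls stays at 1
after_same_shape = len(calls)

double(jnp.ones(5))  # new input shape: JAX retraces, so calls grows to 2
after_new_shape = len(calls)
```

This is why stateful values such as parameters and batch statistics are passed into jitted functions as arguments and returned as outputs, rather than kept as hidden state that the function mutates.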
In this lesson, we’ll create a network with BatchNorm and Dropout layers. We’ll then see how to supply the pseudo-random number generator (PRNG) key that the Dropout layer needs and how to update the batch statistics while training the network.