Distributed Training with JAX and Flax
Learn about distributed training in JAX and Flax.
Training models on accelerators with JAX and Flax differs slightly from training on CPUs. For instance, when using multiple accelerators, the data needs to be replicated across the devices. The training step is then executed on each device, and the results are aggregated. Flax supports both GPU and TPU accelerators.
This lesson will focus on training models with Flax and JAX using GPUs and TPUs.
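For example, a common pattern for feeding data to multiple devices is to reshape each batch so that its leading axis equals the device count. The sketch below is a minimal, hypothetical illustration (the batch shape is assumed, and the batch size is assumed to divide evenly across devices):

```python
import jax
import numpy as np

# Hypothetical batch: 32 RGB images of size 64x64 (shape assumed for illustration).
batch = np.random.rand(32, 64, 64, 3).astype(np.float32)

# Reshape so the leading axis equals the number of devices;
# pmap then sends one shard of the batch to each device.
n_devices = jax.local_device_count()
per_device = batch.shape[0] // n_devices  # assumes an even split
shards = batch.reshape((n_devices, per_device) + batch.shape[1:])
```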
Prior to training, it's important to process the data and create a training state, both of which were covered in an earlier lesson.
Create training state
We now need to create parallel versions of our functions. Parallelization in JAX is done using the `pmap` function, which compiles a function with XLA and executes it on multiple devices.
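As a quick illustration of `pmap` semantics, the toy sketch below (not part of the lesson code) maps a function over the leading axis of an array, so each device computes one partial result:

```python
import jax
import jax.numpy as jnp

# pmap maps a function over the leading axis of its inputs,
# running one slice on each device.
n_devices = jax.local_device_count()
xs = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)  # one row per device

# Each device computes the sum of its own row.
per_device_sums = jax.pmap(jnp.sum)(xs)
print(per_device_sums)  # shape: (n_devices,)
```

With that in mind, the lesson's parallelized training-state function looks as follows: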
```python
from flax.training import train_state
import optax
import functools

@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def create_train_state(rng, learning_rate, momentum):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, size_image, size_image, 3]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)
```
In the code above:
- Lines 1–3: We import `train_state` from the `flax.training` module, along with the `optax` and `functools` libraries.
- Line 5: We apply the `@functools.partial()` decorator with `jax.pmap` for parallel execution of the `create_train_state()` function. We set `static_broadcasted_argnums=(1, 2)` to broadcast `learning_rate` and `momentum` as static values.
- Lines 6–11: We define the `create_train_state()` function, which creates the initial state for model training. It takes three arguments: `rng`, the random number generator key, and `learning_rate` and `momentum`, the parameters of the optimizer. Inside this function:
  - Lines 8–9: We create an instance `cnn` of the `CNN` class and get the initial model parameters `params` by calling the `init()` method of `cnn`. This method takes the random number generator key and a dummy input image, a JAX array of ones.
  - Line 10: We define a stochastic gradient descent optimizer with the provided learning rate and momentum.
  - Line 11: We create and return the train state by calling the `create()` method of the `train_state.TrainState` class. This method takes the model's apply function `cnn.apply`, the initialized `params`, and the optimizer `tx`.
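A hypothetical call to the parallelized function might look like the sketch below; the seed, learning rate, and momentum values are assumed for illustration:

```python
import jax

rng = jax.random.PRNGKey(0)  # assumed seed for illustration

# pmap expects the leading axis of non-static arguments to match the
# device count, so we create one RNG key per device.
rngs = jax.random.split(rng, jax.local_device_count())

# learning_rate and momentum are static (positions 1 and 2), so they are
# passed as plain Python scalars and broadcast to every device.
state = create_train_state(rngs, 0.01, 0.9)
```

The returned train state is replicated, with a leading device axis on every parameter, which is what the parallel training step expects.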