Distributed Training with JAX and Flax

Learn about distributed training in JAX and Flax.

Training models on accelerators with JAX and Flax differs slightly from training on CPUs. For instance, when using multiple accelerators, the data needs to be replicated across the devices. We then execute the training step on every device and aggregate the results. Flax supports both GPU and TPU accelerators.

This lesson will focus on training models with Flax and JAX using GPUs and TPUs.
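
Before parallelizing anything, it helps to check which devices JAX can see, since pmap maps a function over exactly this set of devices. A minimal sketch:

import jax

# List the accelerators visible to JAX; pmap runs one copy of a
# function on each of these devices.
print(jax.devices())        # e.g., [GpuDevice(id=0), GpuDevice(id=1), ...]
print(jax.device_count())   # the number of devices to parallelize over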

Model parallelism

Prior to training, it’s important to process the data and create a training state, both of which were covered in the earlier lesson.

Create training state

We now need to create parallel versions of our functions. Parallelization in JAX is done using the pmap function, which compiles a function with XLA and executes it on multiple devices in parallel.
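
As a minimal sketch of pmap in action (the toy squaring function here is illustrative, not part of the lesson's model), the following maps a function over all available devices. Note that the leading axis of the input must equal the device count:

import jax
import jax.numpy as jnp

# Compile the function with XLA and run one copy per device.
parallel_square = jax.pmap(lambda x: x ** 2)

# One input element per device; each device squares its own slice.
xs = jnp.arange(jax.device_count(), dtype=jnp.float32)
print(parallel_square(xs))  # e.g., [0., 1., 4., ...] across devices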

import functools
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def create_train_state(rng, learning_rate, momentum):
    """Creates initial `TrainState`."""
    cnn = CNN()  # CNN and size_image are defined in the earlier lessons
    params = cnn.init(rng, jnp.ones([1, size_image, size_image, 3]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)

In the code above:

  • Lines 1–5: We import the functools, jax, jax.numpy, and optax libraries, as well as train_state from the flax.training module.

  • Line 7: We apply the @functools.partial() decorator with jax.pmap for parallel execution of the create_train_state() function. We set static_broadcasted_argnums=(1, 2) so that learning_rate and momentum are treated as static values and broadcast to every device (see the usage sketch after this list).

  • Lines 8–13: We define the create_train_state() function that creates the initial state for model training. It takes three arguments: rng is the random number generator key, while learning_rate and momentum are the parameters of the optimizer. Inside this function:

    • Lines 10–11: We create an instance cnn of the CNN class and get the initial model parameters params by calling its init() method. This method takes the random number generator key and a dummy input, a JAX array of ones with the shape of a single image.

    • Line 12: We define a stochastic gradient descent optimizer with the provided learning rate and momentum.

    • Line 13: We create and return the train state by calling the create() method of the train_state.TrainState class. This method takes the model’s apply function apply_fn, the initial parameters params, and the optimizer tx, and bundles them into a single training state.
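
Putting it together, here is a sketch of how the pmapped create_train_state() might be invoked. Because rng is a mapped argument, we pass one PRNG key per device; learning_rate and momentum are static, so plain Python floats work. The learning rate and momentum values below are illustrative:

import jax

# One PRNG key per device: the leading axis of every mapped
# argument must equal the number of devices.
rng = jax.random.PRNGKey(0)
device_rngs = jax.random.split(rng, jax.device_count())

# Static arguments are passed as ordinary Python values.
state = create_train_state(device_rngs, 0.01, 0.9)
# `state` now holds one replica of the model parameters per device.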