
Adaptive Optimizers

Learn about the adaptive optimizers used in JAX and Flax.

AdaBelief

AdaBelief works on the concept of “belief” in the current gradient direction. If the observed gradient stays close to the direction the optimizer predicts (its belief), that direction is trusted and large updates are applied. Otherwise, the direction is distrusted and the step size is reduced.
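Concretely, AdaBelief keeps an exponential moving average of the gradient as its “belief” and scales each step by how far the observed gradient deviates from that belief. Below is a simplified, illustrative sketch of a single update for one parameter tensor. The helper name adabelief_step and the hyperparameter values are assumptions, and the small epsilon the original paper adds inside the second-moment update is omitted for brevity; in practice, Optax’s optax.adabelief handles all of this for us.

import jax.numpy as jnp

def adabelief_step(param, grad, m, s, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-16):
    # EMA of gradients: the optimizer's "belief" about the gradient direction
    m = b1 * m + (1 - b1) * grad
    # EMA of the squared deviation of the observed gradient from that belief
    s = b2 * s + (1 - b2) * (grad - m) ** 2
    # Bias correction, as in Adam
    m_hat = m / (1 - b1 ** t)
    s_hat = s / (1 - b2 ** t)
    # Small deviation (strong belief) -> large step; large deviation -> small step
    param = param - lr * m_hat / (jnp.sqrt(s_hat) + eps)
    return param, m, s

# Hypothetical usage on a scalar parameter
param, m, s = jnp.array(1.0), jnp.array(0.0), jnp.array(0.0)
param, m, s = adabelief_step(param, grad=jnp.array(0.5), m=m, s=s, t=1)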

The authors of AdaBelief introduced the optimizer to:

  • Converge fast, like adaptive methods.
  • Generalize well, like SGD.
  • Remain stable during training.

Let’s look at how to set up the AdaBelief optimizer for a Flax model’s training state.

import optax
seed = random.PRNGKey(0)
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
optimizer = optax.adabelief(learning_rate=learning_rate) # Initialize the AdaBelief optimizer
optimizer_state = optimizer.init(weights) # Initialize the optimizer state from the model weights

In the code above:

  • Line 1: We import the optax library for optimizers.

  • Lines 2–3: We define the random seed variable and the learning rate for the CNN network.

  • Line 4: We create an instance of the CNN model.

  • Lines 5–6: We initialize the model weights on a sample batch from X_train and create the AdaBelief optimizer with the chosen learning rate.

  • Line 7: We initialize the optimizer state from the model weights.
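To see how the optimizer and its state are actually used, here is a hedged sketch of a single training step. The cross_entropy_loss helper and the y_batch labels are assumptions (they are not defined in the snippet above); optimizer.update and optax.apply_updates are the standard Optax calls for turning gradients into parameter updates.

import jax
import optax

def train_step(weights, optimizer_state, X_batch, y_batch):
    # Compute the loss for the current weights
    def loss_fn(params):
        logits = model.apply(params, X_batch)        # forward pass of the Flax CNN
        return cross_entropy_loss(logits, y_batch)   # assumed loss helper from the lesson

    # Loss value and gradients with respect to the weights
    loss, grads = jax.value_and_grad(loss_fn)(weights)
    # AdaBelief transforms the raw gradients into parameter updates
    updates, optimizer_state = optimizer.update(grads, optimizer_state, weights)
    # Apply the updates to obtain the new weights
    weights = optax.apply_updates(weights, updates)
    return weights, optimizer_state, loss

Each call returns the updated weights and optimizer state, which are passed back in on the next batch.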
