Adaptive Optimizers
Learn about the adaptive optimizers used in JAX and Flax.
AdaBelief
AdaBelief works on the concept of “belief” in the current gradient direction. It compares the observed gradient with its exponential moving average, which acts as a prediction. If the observed gradient stays close to the prediction, the direction is trusted and a large update is applied. Otherwise, it’s distrusted and the step size is reduced.
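To make this concrete, here is a minimal sketch of a single AdaBelief-style update written with `jax.numpy` (bias correction is omitted for brevity, and the names `b1`, `b2`, and `eps` follow the usual Adam conventions; this is an illustration, not optax’s internal implementation):

```python
import jax.numpy as jnp

def adabelief_step(param, grad, m, s, lr=1e-3, b1=0.9, b2=0.999, eps=1e-16):
    # EMA of gradients: the "predicted" gradient direction
    m = b1 * m + (1 - b1) * grad
    # EMA of the squared *deviation* between the observed gradient and the prediction
    # (Adam would accumulate grad**2 here instead)
    s = b2 * s + (1 - b2) * (grad - m) ** 2
    # Small deviation (high belief) -> large step; large deviation -> small step
    param = param - lr * m / (jnp.sqrt(s) + eps)
    return param, m, s
```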
The authors of AdaBelief introduced the optimizer to:
- Converge fast, as in adaptive methods.
- Have good generalization like SGD.
- Be stable during training.
Let’s look at how to set up the AdaBelief optimizer for a Flax model.
```python
import optax
seed = random.PRNGKey(0)
learning_rate = jnp.array(1/1e4)

model = CNN()
weights = model.init(seed, X_train[:5])

optimizer = optax.adabelief(learning_rate=learning_rate)  # Initialize AdaBelief optimizer
optimizer_state = optimizer.init(weights)                 # Optimizer state
```
In the code above:
- Line 1: We import the `optax` library for optimizers.
- Lines 2–3: We define the random `seed` variable and the learning rate for the CNN network.
- Lines 5–6: We create the `CNN` model and initialize its weights, using a few training samples for shape inference.
- Lines 8–9: We create the AdaBelief optimizer and initialize its state from the model weights.
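As a rough sketch of how this optimizer state could then be used (assuming a scalar-valued `loss_fn(weights, X, y)` defined elsewhere in the course, which is not shown here), a training step might look like this:

```python
import jax

# Hypothetical training step; `loss_fn(weights, X, y)` is assumed to return a scalar loss.
@jax.jit
def train_step(weights, optimizer_state, X, y):
    grads = jax.grad(loss_fn)(weights, X, y)                          # gradients w.r.t. the weights
    updates, optimizer_state = optimizer.update(grads, optimizer_state, weights)
    weights = optax.apply_updates(weights, updates)                   # apply the AdaBelief updates
    return weights, optimizer_state
```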