
Stochastic Gradient Descent

Explore the implementation and use of stochastic gradient descent with momentum and Nesterov acceleration in JAX and Flax. Understand various optimizers including Noisy SGD, Optimistic Gradient Descent, RMSProp, and Yogi, and learn how to apply them to improve model training and convergence.


SGD implements stochastic gradient descent with support for momentum and Nesterov acceleration (a method for accelerating the convergence of iterative optimization algorithms that is commonly used in machine learning). Momentum speeds up the search for optimal model weights by accelerating gradient descent along directions in which successive gradients agree.

Gradient function
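
To make the momentum and Nesterov options concrete, here is a minimal sketch of how they can be enabled when constructing the optimizer with Optax; the learning rate of 1e-4 and momentum of 0.9 are illustrative values, not recommendations from this lesson:

import optax

# Plain SGD: parameters move directly along the negative gradient.
plain_sgd = optax.sgd(learning_rate=1e-4)

# SGD with momentum: a decaying running average of past gradients (the velocity)
# accelerates updates along directions where successive gradients agree.
momentum_sgd = optax.sgd(learning_rate=1e-4, momentum=0.9)

# Nesterov acceleration: the gradient is effectively evaluated at a look-ahead
# point, which often improves convergence over plain momentum.
nesterov_sgd = optax.sgd(learning_rate=1e-4, momentum=0.9, nesterov=True)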

Let’s understand how to use SGD in the following playground:

import jax.numpy as jnp
import optax
from jax import random

seed = random.PRNGKey(0)  # PRNG key for reproducible weight initialization
learning_rate = jnp.array(1e-4)  # Constant learning rate of 0.0001
model = CNN()  # CNN model defined in an earlier lesson
weights = model.init(seed, X_train[:5])  # Initialize weights using a small sample batch
optimizer = optax.sgd(learning_rate=learning_rate)  # Initialize SGD as the optimizer
optimizer_state = optimizer.init(weights)  # Initialize the optimizer state

In the code above:

...
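
Once the optimizer state has been initialized, each training step computes the gradients of a loss with respect to the weights and lets the optimizer transform them into parameter updates. The sketch below assumes a hypothetical loss_fn and a labeled batch (X_batch, y_batch) that are not part of this lesson; only optimizer.update and optax.apply_updates come from Optax itself:

import jax
import optax

def loss_fn(weights, X_batch, y_batch):
    # Hypothetical loss: mean cross-entropy between model predictions and integer labels.
    logits = model.apply(weights, X_batch)
    return optax.softmax_cross_entropy_with_integer_labels(logits, y_batch).mean()

@jax.jit
def train_step(weights, optimizer_state, X_batch, y_batch):
    # Compute the loss and its gradients with respect to the weights.
    loss, grads = jax.value_and_grad(loss_fn)(weights, X_batch, y_batch)
    # Let SGD turn the raw gradients into parameter updates.
    updates, optimizer_state = optimizer.update(grads, optimizer_state, weights)
    # Apply the updates to obtain the new weights.
    weights = optax.apply_updates(weights, updates)
    return weights, optimizer_state, loss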