...


Stochastic Gradient Descent

Learn about SGD-based optimizers in JAX and Flax.


SGD implements stochastic gradient descent with support for momentum and Nesterov acceleration (a method for speeding up the convergence of iterative optimization algorithms commonly used in machine learning). Momentum helps reach optimal model weights faster by accelerating gradient descent along directions of consistent improvement.
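As a minimal sketch of how these options are enabled in Optax, the momentum and nesterov arguments of optax.sgd can be passed when the optimizer is created. The learning rate and momentum values below (1e-4 and 0.9) are illustrative choices, not values from this lesson:

import optax

# Plain SGD: parameters move in the direction of the negative gradient.
plain_sgd = optax.sgd(learning_rate=1e-4)

# SGD with momentum: an exponential moving average of past gradients
# (decay 0.9 here, an illustrative value) smooths and accelerates updates.
momentum_sgd = optax.sgd(learning_rate=1e-4, momentum=0.9)

# SGD with Nesterov acceleration: the gradient is evaluated at a
# "look-ahead" point, which often converges faster than plain momentum.
nesterov_sgd = optax.sgd(learning_rate=1e-4, momentum=0.9, nesterov=True)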

[Figure: Gradient function]

Let’s understand how to use SGD in the following playground:

import optax
import jax.numpy as jnp
from jax import random

seed = random.PRNGKey(0)                             # PRNG key for reproducible initialization
learning_rate = jnp.array(1e-4)                      # Step size used by the optimizer
model = CNN()                                        # CNN model defined earlier in the lesson
weights = model.init(seed, X_train[:5])              # Initialize model weights with a sample batch
optimizer = optax.sgd(learning_rate=learning_rate)   # Initialize SGD as the optimizer
optimizer_state = optimizer.init(weights)            # Initialize the optimizer state

In the code above:

...