Stochastic Gradient Descent
Learn about SGD-based optimizers in JAX and Flax.
SGD implements stochastic gradient descent with support for momentum and Nesterov acceleration. In its simplest form, each step moves the parameters against the gradient of the loss, scaled by the learning rate.
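The momentum variants may be easier to see in code. Below is a minimal sketch using Optax's optax.sgd, whose momentum and nesterov keyword arguments are part of its public API; the learning rate and momentum values are illustrative:

import optax

# Plain SGD: params <- params - learning_rate * grads
plain_sgd = optax.sgd(learning_rate=1e-4)

# SGD with classical (heavy-ball) momentum
momentum_sgd = optax.sgd(learning_rate=1e-4, momentum=0.9)

# SGD with Nesterov-accelerated momentum
nesterov_sgd = optax.sgd(learning_rate=1e-4, momentum=0.9, nesterov=True)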
Let’s see how to use SGD in the following example:
import jax.numpy as jnp
from jax import random
import optax

seed = random.PRNGKey(0)
learning_rate = jnp.array(1 / 1e4)
model = CNN()  # CNN model defined earlier in the lesson
weights = model.init(seed, X_train[:5])  # Initialize weights with a sample batch
optimizer = optax.sgd(learning_rate=learning_rate)  # Initialize SGD as the optimizer
optimizer_state = optimizer.init(weights)  # Initialize the optimizer state
In the code above:

- We create a PRNG key as the seed and set the learning rate to 1 / 1e4.
- We instantiate the CNN model and initialize its weights by calling model.init with the seed and a sample batch of inputs.
- We create the SGD optimizer with optax.sgd and initialize its state from the model weights with optimizer.init.
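With the optimizer and its state initialized, a single training step transforms gradients into updates and applies them to the weights. The sketch below assumes a loss function loss_fn(weights, X, y) and labels y_train defined elsewhere in the lesson (both hypothetical here); optimizer.update and optax.apply_updates are Optax's standard update API:

import jax
import optax

# weights, optimizer, and optimizer_state are carried over from above.
grads = jax.grad(loss_fn)(weights, X_train[:5], y_train[:5])

# Turn raw gradients into SGD updates and advance the optimizer state.
updates, optimizer_state = optimizer.update(grads, optimizer_state, weights)

# Apply the updates to obtain the new weights.
weights = optax.apply_updates(weights, updates)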