Introduction to Optimizers

Learn about various types of optimizers used in JAX.

What are optimizers?

Optimizers are used while training neural networks to reduce the error between the true and predicted values. A cost (loss) function measures this error, and the optimization is done via gradient descent: at each step, the network's parameters are moved in the direction opposite to the gradient of the cost function, which lowers the cost. In JAX, optimizers are typically taken from the Optax library.

Optimizers can be classified into two broad categories:

  • Adaptive: Adam, Adagrad, Adadelta, and RMSprop.
  • Stochastic gradient descent (SGD): plain SGD, SGD with momentum (also known as the heavy-ball method, HB), and Nesterov Accelerated Gradient (NAG).

Adaptive vs. SGD-based optimizers

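When performing optimization, adaptive optimizers scale each parameter's update using statistics of its past gradients. In Adagrad, for example, squared gradients accumulate over training, so the effective step size starts large and shrinks over time. This helps avoid overshooting a minimum, although, on the non-convex loss surfaces of neural networks, no optimizer is guaranteed to reach the global minimum.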

Adaptive optimizers such as Adam are widely used because they often converge faster, but on some tasks the models they train generalize worse than those trained with well-tuned SGD.

SGD-based optimizers apply a single global learning rate to all parameters, while adaptive optimizers effectively compute a separate learning rate for each parameter from that parameter's gradient history.