
Challenge: Optax

Explore how to implement an Adam optimizer with Optax in JAX, using a simplified linear regression example. Learn to initialize the optimizer state, compute gradients, and update parameters, with support for auto-vectorization. This lesson guides you step by step through building a working optimizer within the JAX ecosystem.

Time to switch back to the implementation side of things. As we mentioned earlier, implementing an optimizer in Optax is pretty straightforward. Let’s elaborate on those points in this guided challenge.
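Before the challenge itself, here is a minimal sketch of that pattern, assuming a toy single-weight linear model and illustrative data (the dataset, learning rate, and step count below are placeholders, not the lesson's exact solution): create the optimizer, initialize its state from the parameters, then loop gradient computation, `optimizer.update`, and `optax.apply_updates`.

```python
import jax
import jax.numpy as jnp
import optax

# Toy data for y = 3x (illustrative values, not the lesson's dataset)
xs = jnp.linspace(-1.0, 1.0, 32)
ys = 3.0 * xs

params = {"w": jnp.array(0.0)}           # single weight, bias fixed at 0

def loss_fn(params, x, y):
    pred = params["w"] * x               # linear model without a bias term
    return jnp.mean((pred - y) ** 2)     # mean squared error

optimizer = optax.adam(learning_rate=0.1)
opt_state = optimizer.init(params)       # Adam's moment estimates live here

@jax.jit
def step(params, opt_state):
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

for _ in range(200):
    params, opt_state, loss = step(params, opt_state)
```

The same three calls (`init`, `update`, `apply_updates`) work unchanged for any Optax optimizer, which is what makes swapping Adam for another gradient transformation a one-line change.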

Prerequisites

We’ll use the linear regression example from the earlier challenge, but we’ll make it simpler by reducing it to a linear function (b will be 0 ...