Machine Learning with JAX

Learn about various machine learning functionalities available in the JAX library.

Taking derivatives with grad()

In JAX, we compute derivatives with jax.grad(), which takes a scalar-valued function and returns a new function that evaluates its gradient.

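```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_logistic(x):
    # Sum of the logistic sigmoid applied to each element of x
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))
```

In the code above:

  • Lines 1–2: We import JAX and its NumPy-compatible module, jax.numpy.

  • Lines 4–7: We define the sum_logistic() function, which applies the logistic sigmoid to each element of its input and sums the results, and compile it with the @jax.jit decorator.

  • Line 9: We generate a JAX array of the values 0.0 through 5.0 and store it in x_small. The trailing dot in 6. makes the values floating-point, which jax.grad() requires.

  • Line 10: We use the jax.grad() function to create a new function that computes the derivative of sum_logistic() with respect to its input, and we store it in derivative_fn.

  • Lines 11–12: We print the original JAX array, x_small, and its derivative computed by derivative_fn.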
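As a sanity check, we can compare the output of jax.grad() with the closed-form derivative of the logistic sigmoid, σ'(x) = σ(x)(1 - σ(x)). Because sum_logistic() is a sum of independent terms, each element of its gradient is simply σ'(x_i). This is a minimal sketch: the sigmoid() helper is introduced here only for illustration, and the tolerance passed to jnp.allclose() is our own choice to absorb float32 rounding.

```python
import jax
import jax.numpy as jnp


def sum_logistic(x):
    # Sum of the logistic sigmoid over all elements of x
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))


def sigmoid(x):
    # Hypothetical helper: elementwise logistic sigmoid
    return 1.0 / (1.0 + jnp.exp(-x))


x_small = jnp.arange(6.)

# Gradient obtained by automatic differentiation
auto_grad = jax.grad(sum_logistic)(x_small)

# Closed form: d/dx_i sum_j sigmoid(x_j) = sigmoid(x_i) * (1 - sigmoid(x_i))
manual_grad = sigmoid(x_small) * (1.0 - sigmoid(x_small))

print("Autodiff:   ", auto_grad)
print("Closed form:", manual_grad)
print("Match:", jnp.allclose(auto_grad, manual_grad, atol=1e-6))
```

At x = 0 both approaches give 0.25, the maximum slope of the sigmoid, and the values shrink as x grows.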

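The grad function has a has_aux argument that lets the differentiated function return auxiliary data. With has_aux=True, the function must return a pair (output, auxiliary_data), where output is the scalar to differentiate; jax.grad() then returns a pair (gradient, auxiliary_data). When building machine learning models, this is useful for returning the loss value or other metrics along with the gradients.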

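```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_logistic(x):
    # Return a pair: the scalar to differentiate, then the auxiliary data
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), (x + 1)

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic, has_aux=True)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))
```

In the code above:

  • Line 7: The sum_logistic() function now returns a pair: the scalar sum to differentiate and the auxiliary data, x + 1.
  • Line 10: We pass has_aux=True to jax.grad() so that it differentiates only the first element of the returned pair and passes the second element through unchanged.
  • Line 12: We print the result of derivative_fn(x_small), which is a tuple (gradient, auxiliary_data). The output shows the auxiliary data alongside the derivative.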
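A common training pattern is to obtain the loss value, its gradients, and some auxiliary output in a single call. The following minimal sketch does this with jax.value_and_grad(), which also accepts has_aux; with has_aux=True it returns ((loss, aux), grads). The linear model, loss_fn(), and data below are hypothetical and chosen only so that the numbers are easy to check by hand.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical linear model: predictions = x @ w + b
    w, b = params
    preds = x @ w + b
    loss = jnp.mean((preds - y) ** 2)  # mean squared error
    # Scalar loss first (differentiated), predictions second (auxiliary)
    return loss, preds


# Differentiates with respect to the first argument (params) by default
loss_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
params = (jnp.zeros(2), jnp.array(0.0))

(loss, preds), grads = loss_and_grad_fn(params, x, y)
print("Loss:", loss)
print("Predictions:", preds)
print("Gradient w.r.t. w:", grads[0])
print("Gradient w.r.t. b:", grads[1])
```

With all parameters at zero, the predictions are zero, so the loss is (1² + 2²) / 2 = 2.5, and the gradients with respect to w and b are [-7, -10] and -3. The gradients have the same structure as params, so they can be used directly in an update step.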

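For more advanced automatic differentiation, JAX provides jax.jvp(), which computes Jacobian-vector products (forward mode), and jax.vjp(), which computes vector-Jacobian products (reverse mode).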
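To illustrate the difference between the two, the following sketch uses a hypothetical linear map f(x) = A @ x, whose Jacobian is exactly A. jax.jvp() pushes a tangent vector v forward and returns J @ v, while jax.vjp() returns a function that pulls a cotangent vector u back and computes u @ J. The matrix A and the vectors are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical 2x3 matrix; f maps R^3 -> R^2, so its Jacobian is A itself
A = jnp.array([[1.0, 2.0, 0.0],
               [0.0, 3.0, 4.0]])


def f(x):
    return A @ x


x = jnp.array([1.0, 1.0, 1.0])

# Forward mode: primals and tangents are passed as tuples
v = jnp.array([1.0, 0.0, 0.0])
y, jv = jax.jvp(f, (x,), (v,))  # jv = A @ v, the first column of A

# Reverse mode: vjp returns f(x) and a function of the cotangent u
y2, vjp_fn = jax.vjp(f, x)
u = jnp.array([1.0, 0.0])
(uj,) = vjp_fn(u)  # uj = u @ A, the first row of A

print("f(x): ", y)
print("J @ v:", jv)
print("u @ J:", uj)
```

Forward mode is efficient when a function has few inputs and many outputs; reverse mode, which jax.grad() builds on, is efficient when there are many inputs and a single scalar output, as with a loss function.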

Auto-vectorization with vmap

vmap (vectorizing map) allows us to write a function for a single data point; vmap then maps it over a batch of data. Without vmap, the solution would be to loop through the ...
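A minimal sketch of the idea, assuming a hypothetical per-example function predict() that computes a dot product: an explicit Python loop and the jax.vmap() version produce the same result, but vmap builds a single vectorized computation instead of calling the function once per example. The in_axes=(None, 0) argument tells vmap not to map over w and to map over the first axis of batch_x.

```python
import jax
import jax.numpy as jnp


def predict(w, x):
    # Hypothetical per-example function: w and x both have shape (3,)
    return jnp.dot(w, x)


w = jnp.array([1.0, 2.0, 3.0])
batch_x = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 examples

# Manual batching: loop over the rows and stack the scalar results
looped = jnp.stack([predict(w, x) for x in batch_x])

# vmap: share w across the batch, map over axis 0 of batch_x
batched_predict = jax.vmap(predict, in_axes=(None, 0))
vectorized = batched_predict(w, batch_x)

print("Loop:", looped)
print("vmap:", vectorized)
```

Both approaches give [8, 26, 44, 62] here; the vmapped function can also be combined with jax.jit() or jax.grad().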
