Machine Learning with JAX
Learn about various machine learning functionalities available in the JAX library.
Taking derivatives with grad()
Computing derivatives in JAX is done using the jax.grad() function.
@jax.jit
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))
In the code above:

- Lines 1–3: We apply the @jax.jit decorator to the sum_logistic() function.
- Line 5: We generate a JAX array of values from zero to five and store it in x_small.
- Line 6: We use the jax.grad() function to calculate the derivative of the sum_logistic() function with respect to its input. We store the derivative function in derivative_fn.
- Lines 7–8: We print the original JAX array, x_small, and its derivative computed with derivative_fn.
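Since sum_logistic() sums the element-wise logistic (sigmoid) function, its gradient with respect to x is the element-wise sigmoid derivative, sigmoid(x) * (1 - sigmoid(x)). As a rough sanity check (the sigmoid() helper below is our own addition, not part of the lesson's code), we can compare the automatic derivative against this closed form:

import jax
import jax.numpy as jnp

def sigmoid(x):
    # Element-wise logistic function
    return 1.0 / (1.0 + jnp.exp(-x))

def sum_logistic(x):
    # Scalar output: the sum of element-wise sigmoids
    return jnp.sum(sigmoid(x))

x_small = jnp.arange(6.)
auto_grad = jax.grad(sum_logistic)(x_small)                # automatic differentiation
manual_grad = sigmoid(x_small) * (1.0 - sigmoid(x_small))  # closed-form sigmoid derivative
print(jnp.allclose(auto_grad, manual_grad))                # expected: True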
The grad function has a has_aux argument that allows the differentiated function to return auxiliary data alongside the value being differentiated. For example, when building machine learning models, we can use it to return extra outputs, such as predictions or metrics, together with the gradients of the loss.
@jax.jit
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), (x + 1)

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic, has_aux=True)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))
In the code above:
- Line 6: We pass True to the has_aux argument to tell jax.grad() that the sum_logistic() function returns auxiliary data as its second output.
- Line 8: We print the derivative of x_small using derivative_fn. We can see the auxiliary data along with the derivative results in the output.
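For a more realistic picture of how has_aux is used when training models, here is a minimal sketch in which a loss function returns its predictions as auxiliary data; the loss_fn(), params, and toy inputs below are hypothetical placeholders, not part of the lesson's code:

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear model; predictions are returned as auxiliary data
    preds = x @ params["w"] + params["b"]
    loss = jnp.mean((preds - y) ** 2)
    return loss, preds

params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
x = jnp.ones((4, 3))
y = jnp.zeros((4,))

# With has_aux=True, jax.grad() returns (gradients, auxiliary_data)
grads, preds = jax.grad(loss_fn, has_aux=True)(params, x, y)
print(grads["w"].shape, preds.shape)  # (3,) (4,)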
We can perform advanced automatic differentiation using jax.vjp() and jax.jvp().
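As a brief illustrative sketch (reusing the sum_logistic() function from above), jax.jvp() computes a forward-mode Jacobian-vector product, while jax.vjp() returns the function's output along with a callable that computes reverse-mode vector-Jacobian products:

import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(6.)
tangent = jnp.ones_like(x_small)

# Forward mode: the function's output and its directional derivative along `tangent`
primal_out, jvp_out = jax.jvp(sum_logistic, (x_small,), (tangent,))

# Reverse mode: the output and a function mapping an output cotangent to input cotangents
primal_out, vjp_fn = jax.vjp(sum_logistic, x_small)
(grad_x,) = vjp_fn(1.0)

print(jvp_out, grad_x)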
Auto-vectorization with vmap
The vmap (vectorizing map) transformation allows us to write a function that operates on a single data point, and vmap will then map it over a batch of data. Without vmap, the solution would be to loop through the ...