Machine Learning with JAX
Learn about various machine learning functionalities available in the JAX library.
Taking derivatives with grad()
Computing derivatives in JAX is done using jax.grad.
```python
@jax.jit
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))
```
In the code above:
- Lines 1–3: We apply the `@jax.jit` decorator to the `sum_logistic()` function, which sums the logistic (sigmoid) function applied elementwise to its input.
- Line 5: We generate a JAX array of values from zero to five and store it in `x_small`.
- Line 6: We use the `jax.grad()` function to calculate the derivative of the `sum_logistic()` function with respect to its input, and store the resulting derivative function in `derivative_fn`.
- Lines 7–8: We print the original JAX array, `x_small`, and its derivative computed with `derivative_fn`.
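As a quick check (our own addition, not part of the lesson's snippet), the logistic function σ(x) = 1 / (1 + e^(-x)) has the closed-form derivative σ(x)(1 - σ(x)), so we can compare the output of `jax.grad()` against it directly; the `logistic_derivative()` helper below is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

# Closed-form derivative of the logistic function: sigma(x) * (1 - sigma(x)).
def logistic_derivative(x):
    s = 1.0 / (1.0 + jnp.exp(-x))
    return s * (1.0 - s)

x_small = jnp.arange(6.)
print(jax.grad(sum_logistic)(x_small))  # automatic differentiation
print(logistic_derivative(x_small))     # closed form; the two should match
```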
The `jax.grad()` function has a `has_aux` argument that allows the differentiated function to return auxiliary data alongside its scalar output. For example, when building machine learning models, we can use it to return extra values, such as metrics or model state, along with the gradients.
```python
@jax.jit
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), (x + 1)

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic, has_aux=True)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))
```
In the code above:
- Line 6: We pass `True` to the `has_aux` argument so that `jax.grad()` treats the second value returned by `sum_logistic()` as auxiliary data.
- Line 8: We print the derivative of `x_small` using `derivative_fn`. The auxiliary data appears alongside the derivative results in the output.
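To connect `has_aux` to model training, here is a minimal sketch (our own example, not code from this lesson; the loss function, weights, and data below are made up) that uses `jax.value_and_grad()` with `has_aux=True` so a single call returns the loss, auxiliary metrics, and the gradients:

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    pred = x @ w
    loss = jnp.mean((pred - y) ** 2)      # scalar loss (differentiated)
    aux = {"pred_mean": pred.mean()}      # auxiliary metrics (not differentiated)
    return loss, aux

key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (3,))
x = jax.random.normal(key, (8, 3))
y = jnp.ones(8)

# With has_aux=True, value_and_grad returns ((loss, aux), grads) in one pass.
(loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(w, x, y)
print(loss, aux["pred_mean"], grads.shape)
```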
We can perform advanced automatic differentiation using jax.vjp() and jax.jvp().
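As a rough sketch of what those two transformations do (the function `f` below is our own example, not from the lesson): `jax.jvp()` computes a Jacobian-vector product in forward mode, while `jax.vjp()` returns the function's output together with a callable that computes vector-Jacobian products in reverse mode.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.arange(3.)
tangent = jnp.ones_like(x)

# Forward mode: primal output and Jacobian-vector product in one pass.
primal_out, jvp_out = jax.jvp(f, (x,), (tangent,))

# Reverse mode: primal output and a function that computes vector-Jacobian products.
primal_out2, vjp_fn = jax.vjp(f, x)
(vjp_out,) = vjp_fn(jnp.ones_like(primal_out2))

print(jvp_out)  # J(x) @ tangent
print(vjp_out)  # cotangent @ J(x)
```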
Auto-vectorization with vmap
The vmap (vectorizing map) transformation lets us write a function that operates on a single data point and then maps it over a batch of data. Without vmap, we would loop over the batch in Python while applying the function; combining jit with such Python for loops is awkward and often slower.
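For contrast, a loop-based version might look like the following minimal sketch (our own illustrative code, reusing the same random setup as the lesson's snippet below); vmap replaces this Python loop with a single vectorized call.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(98)
mat = jax.random.normal(key, (150, 100))
batched_x = jax.random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

# Naive batching: call the single-example function in a Python loop and stack the results.
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

print(naively_batched_apply_matrix(batched_x).shape)  # (10, 150)
```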
```python
seed = 98
key = jax.random.PRNGKey(seed)
mat = jax.random.normal(key, (150, 100))
batched_x = jax.random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

@jax.jit
def vmap_batched_apply_matrix(v_batched):
    return jax.vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
start_time = time.time()
print(vmap_batched_apply_matrix(batched_x).block_until_ready())
print("--- Execution time: %s seconds ---" % (time.time() - start_time))
```
In the code above:
- Lines 1–4: We set a seed, create a PRNG key from it, and generate a random matrix, `mat`, with dimensions of 150×100, along with a batch of random inputs, `batched_x`, with dimensions of 10×100.