Computing and Monitoring Loss in JAX
Learn about various ways to calculate and monitor loss in models using JAX.
Computing loss with JAX metrics
JAX Metrics (the jax_metrics package) is an open-source library for computing losses and metrics in JAX through a Keras-like API. For example, here is how we use the library to compute the cross-entropy loss:
import jax.numpy as jnp
import jax_metrics as jm
crossentropy = jm.losses.Crossentropy()

logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.0, 1.0, 1.0, 0.0, 0.0])

print(crossentropy(target=labels, preds=logits))
print(jm.losses.crossentropy(target=labels, preds=logits))
In the code above:
- Lines 1–2: We import jax.numpy, which provides jnp, and the jax_metrics library used to calculate the cross-entropy loss.
- Line 3: We create an instance of the Crossentropy() loss class.
- Lines 5–6: We define two JAX arrays: logits and labels.
- Line 8: We compute the cross entropy by calling the crossentropy instance.
- Line 9: We compute the cross entropy by calling the jm.losses.crossentropy() function. It is an alternative, functional syntax for computing the same loss.
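To make the quantity being computed more concrete, here is a minimal sketch of a categorical cross-entropy loss written directly with jax.numpy. It is an illustration only and is not necessarily how jax_metrics implements its loss. The function name softmax_cross_entropy, the two-dimensional logits, and the integer class labels are assumptions chosen for this example; they differ from the arrays used above.

```python
import jax
import jax.numpy as jnp


def softmax_cross_entropy(logits, labels):
    """Mean cross entropy for integer class labels and unnormalized logits.

    Hypothetical helper for illustration; not part of jax_metrics.
    logits: float array of shape (batch, num_classes)
    labels: int array of shape (batch,)
    """
    # Convert logits to log-probabilities in a numerically stable way.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Select the log-probability assigned to the true class of each example.
    true_log_probs = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
    # Cross entropy is the negative mean log-likelihood of the true classes.
    return -jnp.mean(true_log_probs)


# Assumed example data: two examples, three classes.
logits = jnp.array([[2.0, 0.5, 0.1],
                    [0.2, 1.5, 0.3]])
labels = jnp.array([0, 1])

loss = softmax_cross_entropy(logits, labels)
print(loss)
```

A useful sanity check for a function like this is that all-zero logits give a loss of log(num_classes), because the model then assigns equal probability to every class.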
Here ...