Computing and Monitoring Loss in JAX
Learn about various ways to calculate and monitor loss in models using JAX.
Computing loss with JAX metrics
JAX Metrics (`jax_metrics`) is an open-source package for computing losses and metrics in JAX. It provides a Keras-like API for computing a model's loss and metrics. For example, here is how we use the library to compute the cross-entropy loss:
```python
import jax.numpy as jnp
import jax_metrics as jm

crossentropy = jm.losses.Crossentropy()  # loss object with default settings

logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])  # model predictions
labels = jnp.array([0.0, 1.0, 1.0, 0.0, 0.0])       # targets

print(crossentropy(target=labels, preds=logits))            # object-based API
print(jm.losses.crossentropy(target=labels, preds=logits))  # functional API
```

In the code above:
Lines 1–2: We import JAX's NumPy API as `jnp`, which we need to build the arrays. We also import the `jax_metrics` library, which computes the cross-entropy loss.
Line 4: We create an instance of the `Crossentropy` loss class.
Lines 6–7: We define two JAX arrays: `logits`, the model's predictions, and `labels`, the targets.
Line 9: We compute the cross-entropy by calling the `crossentropy` object directly, because `Crossentropy` instances are callable.
Line 10: We compute the cross-entropy with the `jm.losses.crossentropy()` function. This is an alternative syntax that doesn't require creating a loss object first. ...
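To see what a cross-entropy loss actually computes, here is a minimal sketch in plain JAX that is independent of `jax_metrics`. It applies two common variants to the same arrays:

- Categorical cross-entropy applies a softmax over the last axis.
- Binary cross-entropy treats each entry as an independent yes/no prediction passed through a sigmoid.

The helper names `categorical_crossentropy` and `binary_crossentropy` are hypothetical. We don't assume which variant, input convention, or reduction `jax_metrics` uses by default, so its printed values may differ from these.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper (not part of jax_metrics): categorical cross-entropy.
# Assumes the last axis holds unnormalized class scores (logits).
def categorical_crossentropy(target, logits):
    # log_softmax is a numerically stable log(softmax(logits))
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Negative sum over classes of target * log(probability)
    return -jnp.sum(target * log_probs, axis=-1)


# Hypothetical helper (not part of jax_metrics): binary cross-entropy.
# Treats every entry as an independent yes/no prediction and averages
# the per-entry losses over the last axis.
def binary_crossentropy(target, logits):
    log_p = jax.nn.log_sigmoid(logits)        # log(sigmoid(x))
    log_not_p = jax.nn.log_sigmoid(-logits)   # log(1 - sigmoid(x))
    per_entry = -(target * log_p + (1.0 - target) * log_not_p)
    return jnp.mean(per_entry, axis=-1)


logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.0, 1.0, 1.0, 0.0, 0.0])

print(categorical_crossentropy(labels, logits))
print(binary_crossentropy(labels, logits))
```

Both helpers take logits and use `jax.nn.log_softmax` and `jax.nn.log_sigmoid` rather than taking the log of a probability directly. This avoids computing `log(0)` when a probability underflows.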