
Computing and Monitoring Loss in JAX

Explore how to compute different loss functions such as cross entropy, Huber loss, and mean squared error using the jax_metrics library. Understand methods to monitor training and validation loss in JAX neural networks, and learn how to detect and handle NaN errors for stable model training.

Computing loss with JAX metrics

JAX Metrics (jax_metrics) is an open-source library for computing losses and metrics in JAX. It provides a Keras-like API for both. For example, here is how to use it to compute the cross entropy loss:

```python
import jax.numpy as jnp
import jax_metrics as jm

crossentropy = jm.losses.Crossentropy()  # class-based (Keras-like) loss object

logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])  # model outputs
labels = jnp.array([0.0, 1.0, 1.0, 0.0, 0.0])       # targets

print(crossentropy(target=labels, preds=logits))            # call the loss object
print(jm.losses.crossentropy(target=labels, preds=logits))  # functional form
```

In the code above:

  • Lines 1–2: We import jax.numpy (as jnp) to create arrays and the jax_metrics library (as jm) to compute the cross entropy loss.

  • Line 4: We create an instance of the Crossentropy loss class.

  • Lines 6–7: We define two JAX arrays: logits and labels.

  • Line 9: We compute the cross entropy by calling the crossentropy instance with the targets and predictions.

  • Line 10: We compute the cross entropy with the jm.losses.crossentropy() function. This functional form is an alternative to the class-based syntax.
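To see what a cross entropy loss computes under the hood, here is a minimal hand-written sketch in plain JAX. It assumes each of the five entries is an independent binary prediction (a logit paired with a 0/1 label). That is our assumption: the library's Crossentropy() may treat the inputs differently by default (for example, as one categorical distribution), so its printed values need not match this one. The helper name binary_crossentropy is hypothetical and not part of jax_metrics.

```python
import jax
import jax.numpy as jnp


def binary_crossentropy(target, logits):
    """Hypothetical helper: mean binary cross entropy computed from logits."""
    # log(sigmoid(x)) and log(1 - sigmoid(x)) == log(sigmoid(-x)),
    # both computed with log_sigmoid for numerical stability.
    log_p = jax.nn.log_sigmoid(logits)
    log_not_p = jax.nn.log_sigmoid(-logits)
    # Per-element loss: -[y * log(p) + (1 - y) * log(1 - p)], then averaged.
    return -jnp.mean(target * log_p + (1.0 - target) * log_not_p)


logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.0, 1.0, 1.0, 0.0, 0.0])

print(binary_crossentropy(labels, logits))
```

Computing log(1 - sigmoid(x)) as log_sigmoid(-x) avoids taking log(0) when a logit is very large, which is one common way a loss turns into inf or NaN during training.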

Here ...