Computing and Monitoring Loss in JAX

Learn about various ways to calculate and monitor loss in models using JAX.

Computing loss with JAX metrics

JAX Metrics is an open-source package that provides a Keras-like API for computing losses and metrics in JAX. For example, here is how we use the library to compute the cross-entropy loss:

import jax.numpy as jnp
import jax_metrics as jm

crossentropy = jm.losses.Crossentropy()  # object-style API

logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.0, 1.0, 1.0, 0.0, 0.0])

print(crossentropy(target=labels, preds=logits))
print(jm.losses.crossentropy(target=labels, preds=logits))  # functional API

In the code above:

  • Lines 1–2: We import jax.numpy as jnp to build the arrays and the jax_metrics library as jm to calculate the cross-entropy loss.

  • Line 4: We create an instance of the Crossentropy loss class.

  • Lines 6–7: We define two JAX arrays: logits and labels.

  • Line 9: We compute the cross-entropy loss by calling the crossentropy instance, passing labels as target and logits as preds.

  • Line 10: We compute the cross-entropy loss by calling the jm.losses.crossentropy() function directly. It is an alternative syntax to compute the loss. ...
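To see what a cross-entropy loss computes, it can help to write the standard softmax cross-entropy formula by hand in plain JAX. The sketch below is only an illustration of that textbook definition: the function name softmax_cross_entropy and the three-class example values are hypothetical, and it is an assumption (not something stated above) that the library's defaults reduce to this formula.

```python
import jax
import jax.numpy as jnp


def softmax_cross_entropy(logits, target):
    """Cross entropy between a target distribution and softmax(logits).

    Hypothetical helper for illustration; not part of jax_metrics.
    """
    # log_softmax is numerically more stable than log(softmax(...)).
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Weight each class's negative log-probability by the target distribution.
    return -jnp.sum(target * log_probs, axis=-1)


# Made-up single example with three classes; class 1 is the true class.
logits = jnp.array([2.0, 1.0, 0.1])
target = jnp.array([0.0, 1.0, 0.0])

loss = softmax_cross_entropy(logits, target)
print(loss)
```

Because the target here is one-hot, the loss reduces to the negative log-probability that the softmax assigns to the true class, so it shrinks as the logit for that class grows relative to the others.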
