Solution: Loss and Activation Functions
Let’s review the solution.
Solution 1: Implementation of loss and activation functions
Here is the implementation of the following loss and activation functions:
- Softmax cross-entropy
- Cosine similarity
- Huber loss
- L2 loss
- CELU
- Softplus
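For reference, these library calls follow the standard textbook definitions sketched below (a summary only; the Optax and JAX implementations add numerical-stability details, the `epsilon` argument of `cosine_similarity` guards against division by zero, and δ and α are the Huber and CELU parameters, both defaulting to 1):

$$
\begin{aligned}
\text{softmax\_cross\_entropy}(x, y) &= -\sum_i y_i \,\log\big(\mathrm{softmax}(x)\big)_i \\
\text{cosine\_similarity}(x, y) &= \frac{x \cdot y}{\lVert x \rVert \,\lVert y \rVert} \\
\text{huber\_loss}(x, y) &= \begin{cases} \tfrac{1}{2}(x - y)^2 & \text{if } |x - y| \le \delta \\ \delta\,\big(|x - y| - \tfrac{1}{2}\delta\big) & \text{otherwise} \end{cases} \\
\text{l2\_loss}(x, y) &= \tfrac{1}{2}(x - y)^2 \\
\text{celu}(x) &= \max(0, x) + \min\big(0,\ \alpha\,(e^{x/\alpha} - 1)\big) \\
\text{softplus}(x) &= \log\big(1 + e^{x}\big)
\end{aligned}
$$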
```python
import jax
import jax.numpy as jnp
import optax
logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])

# Compute and print each loss and activation value
print("Softmax cross-entropy: ", optax.softmax_cross_entropy(logits, labels))
print("Cosine similarity: ", optax.cosine_similarity(logits, labels, epsilon=0.5))
print("Huber loss: ", optax.huber_loss(logits, labels))
print("L2 loss: ", optax.l2_loss(logits, labels))
print("CELU: ", jax.nn.celu(logits))
print("Softplus: ", jax.nn.softplus(logits))
```
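If you want to double-check what these functions return, a minimal sketch like the one below (not part of the original solution; `manual_ce` and `manual_softplus` are just illustrative names) recomputes two of the values directly from their definitions:

```python
import jax
import jax.numpy as jnp
import optax

logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])

# Softmax cross-entropy is -sum(labels * log_softmax(logits)).
manual_ce = -jnp.sum(labels * jax.nn.log_softmax(logits))
print(jnp.allclose(manual_ce, optax.softmax_cross_entropy(logits, labels)))  # expected: True

# Softplus is log(1 + exp(x)); log1p keeps the manual version stable here.
manual_softplus = jnp.log1p(jnp.exp(logits))
print(jnp.allclose(manual_softplus, jax.nn.softplus(logits)))  # expected: True
```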
Let’s review the code line-by-line:
- Lines 1–3: We import all the necessary libraries. JAX, or more specifically `jax.nn`, provides the activation functions, while Optax contains the optimizer and loss functions.
- Lines 4–5: We define the JAX arrays `logits` and `labels`.
- Lines 8–13: We call the loss and activation functions on these arrays and print their values (see the sketch after this list for how such a loss is typically used in training).
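As a small follow-up sketch (an assumption about typical usage, not part of the original solution), here is how one of these Optax losses would usually be combined with `jax.grad` during training; writing the loss as a function of the logits is what lets JAX differentiate it:

```python
import jax
import jax.numpy as jnp
import optax

logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])

def loss_fn(logits, labels):
    # optax.softmax_cross_entropy returns a scalar for 1-D inputs,
    # so jax.grad can differentiate it with respect to the logits.
    return optax.softmax_cross_entropy(logits, labels)

grads = jax.grad(loss_fn)(logits, labels)
print("Gradient of the loss w.r.t. the logits:", grads)
```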