...

Solution: Loss and Activation Functions

Let’s review the solution.

Solution 1: Implementation of loss and activation functions

Here is the implementation of the following functions:

  1. Softmax cross-entropy
  2. Cosine similarity
  3. Huber loss
  4. L2 loss
  5. CELU
  6. Softplus
import jax
import jax.numpy as jnp
import optax
logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])

# Print each loss and activation on the example arrays
print("Softmax cross-entropy: ", optax.softmax_cross_entropy(logits, labels))
print("Cosine similarity: ", optax.cosine_similarity(logits, labels, epsilon=0.5))
print("Huber loss: ", optax.huber_loss(logits, labels))
print("L2 loss: ", optax.l2_loss(logits, labels))
print("CELU: ", jax.nn.celu(logits))
print("Softplus: ", jax.nn.softplus(logits))

Let’s review the solution code line-by-line:

  • Lines 1–3: We import the necessary libraries. JAX (specifically jax.nn) provides the activation functions, while optax contains the optimizer and loss functions.

  • Lines 4–5: We define the JAX arrays: logits and labels.

  • Lines 8–13: We call the loss functions from optax (softmax_cross_entropy, cosine_similarity, huber_loss, and l2_loss) and the activation functions from jax.nn (celu and softplus), printing each result.
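
To make explicit what several of these calls compute, here is a sketch that re-derives Huber loss, L2 loss, CELU, and Softplus with plain jax.numpy, assuming the libraries' default parameters (delta = 1.0 for huber_loss, alpha = 1.0 for celu). Up to floating-point error, each printed array should match the corresponding library output above.

import jax.numpy as jnp

logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])

# Huber loss (elementwise, delta = 1.0): quadratic for small errors, linear for large ones
delta = 1.0
err = jnp.abs(logits - labels)
print("Manual Huber loss: ", jnp.where(err <= delta, 0.5 * err**2, delta * (err - 0.5 * delta)))

# L2 loss (elementwise): half the squared error, following optax's convention
print("Manual L2 loss: ", 0.5 * (logits - labels)**2)

# CELU (alpha = 1.0): max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
alpha = 1.0
print("Manual CELU: ", jnp.maximum(logits, 0) + jnp.minimum(0, alpha * (jnp.exp(logits / alpha) - 1)))

# Softplus: log(1 + exp(x)), a smooth approximation of ReLU
print("Manual Softplus: ", jnp.log1p(jnp.exp(logits)))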