Activation Functions in JAX

Learn about various activation functions used to train models in JAX.

Overview

Activation functions are applied in neural networks to shape each layer's output into the desired form, often by constraining it to a specific range. For instance, when solving a binary classification problem, the output should be a number between 0 and 1, which represents the probability of an item belonging to one of the two classes.

[Figure: Neural network]

In a regression problem, however, we want a numerical prediction of a quantity, for example, the price of an item. We should, therefore, choose an activation function appropriate to the problem being solved. Let's look at common activation functions in JAX and Flax.
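As a quick illustration of how the output activation depends on the task, here is a minimal sketch (with made-up values, not part of the lesson's model) that uses jax.nn.sigmoid to squash raw scores into the (0, 1) range for binary classification, while a regression output is simply left unchanged:

import jax
import jax.numpy as jnp

raw_scores = jnp.array([-2.0, 0.0, 3.5])  # unbounded model outputs (logits)

# Binary classification: sigmoid squashes any real number into (0, 1),
# so each value can be read as a probability.
probabilities = jax.nn.sigmoid(raw_scores)
print(probabilities)

# Regression: no activation is applied; the raw value itself is the prediction.
predictions = raw_scores
print(predictions)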

ReLU

The Rectified Linear Unit (ReLU) activation function is primarily used in the hidden layers of neural networks to introduce non-linearity. It replaces negative inputs with zero and passes non-negative inputs through unchanged, that is, ReLU(x) = max(0, x). This ensures that the layer never produces negative activations.

[Figure: Graph of the ReLU activation function]
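Before looking at a full model, here is a tiny sketch (with made-up values) using jax.nn.relu, the same operation that Flax exposes as nn.relu, showing how negative entries are zeroed out while non-negative ones pass through unchanged:

import jax
import jax.numpy as jnp

x = jnp.array([-3.0, -0.5, 0.0, 2.0, 7.0])  # example inputs, positive and negative
print(jax.nn.relu(x))  # [0. 0. 0. 2. 7.] -- negatives become zero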

Let’s understand how to apply the ReLU activation function in the following code snippet:

import flax
from flax import linen as nn
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)  # ReLU after the first convolution
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)  # ReLU after the second convolution
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten before the dense layers
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)  # ReLU after the first dense layer
        x = nn.Dense(features=2)(x)
        x = nn.log_softmax(x)  # log-probabilities for the two classes
        return x

In lines 7, 10, and 14, we apply the ReLU activation function to the outputs of the convolutional and dense layers.
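To put the model to use, a typical Flax pattern is to initialize the parameters with model.init and run the forward pass with model.apply. The input shape below (a batch of 28×28 grayscale images) is only an assumption for illustration:

import jax
import jax.numpy as jnp

model = CNN()
dummy_input = jnp.ones((1, 28, 28, 1))  # assumed shape: batch of 28x28 grayscale images
params = model.init(jax.random.PRNGKey(0), dummy_input)  # initialize the model parameters
log_probs = model.apply(params, dummy_input)  # forward pass; returns log-probabilities
print(log_probs.shape)  # (1, 2)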