Activation Functions in JAX
Learn about various activation functions used to train models in JAX.
Overview
Activation functions are applied in neural networks to shape the network's output into the desired form. An activation function constrains the output to a specific range. For instance, when solving a binary classification problem, the output should be a number between 0 and 1, which indicates the probability of an item belonging to one of the two classes.
In a regression problem, however, we want an unbounded numerical prediction of a quantity, for example, the price of an item. We should therefore choose an activation function that suits the problem being solved. Let's look at common activation functions in JAX and Flax.
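As a minimal sketch of this idea (not part of the lesson's own code), the snippet below uses the sigmoid activation from jax.nn to squash raw scores into the (0, 1) range suitable for binary classification; the example values are assumptions for illustration only.

```python
# Sketch: sigmoid maps raw model scores (logits) into (0, 1),
# which can be read as class probabilities in binary classification.
import jax.numpy as jnp
from jax import nn as jnn

logits = jnp.array([-2.0, 0.0, 3.5])   # raw, unbounded model outputs
probabilities = jnn.sigmoid(logits)    # each value now lies in (0, 1)
print(probabilities)                   # approximately [0.119, 0.5, 0.971]
```

For a regression output, by contrast, the final layer is typically left without such a squashing activation so the prediction can take any value.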
ReLU
The Rectified Linear Unit (ReLU) activation function is primarily used in the hidden layers of neural networks to introduce non-linearity. The function returns zero for any negative input and returns non-negative inputs unchanged, so its output is always zero or above. This ensures that no negative activations are passed on to the next layer.
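A small sketch (not part of the original snippet) of this element-wise behavior, using an assumed example array:

```python
# Sketch: relu(x) = max(0, x) applied element-wise.
# Negative values become 0; non-negative values pass through unchanged.
import jax.numpy as jnp
from flax import linen as nn

x = jnp.array([-3.0, -0.5, 0.0, 2.0, 7.0])
print(nn.relu(x))   # [0. 0. 0. 2. 7.]
```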
Let’s see how the ReLU activation function is applied inside a Flax model in the following code snippet:
```python
import flax
from flax import linen as nn
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=2)(x)
    x = nn.log_softmax(x)
    return x
```
In lines 7, 10, and 14, we apply the ReLU activation function after each convolutional layer and after the first dense layer, so that the network can learn non-linear relationships in the data.
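To see the model in action, a hypothetical usage sketch is shown below: it initializes the CNN on a dummy batch and runs a forward pass. The input shape (28x28 grayscale images) and the PRNG seed are assumptions for illustration, not part of the lesson.

```python
# Sketch: initialize the CNN above on a dummy batch and run a forward pass.
import jax
import jax.numpy as jnp

model = CNN()
dummy_batch = jnp.ones((1, 28, 28, 1))                    # (batch, height, width, channels), assumed shape
params = model.init(jax.random.PRNGKey(0), dummy_batch)   # initialize the model parameters
log_probs = model.apply(params, dummy_batch)              # forward pass through the network
print(log_probs.shape)                                    # (1, 2): log-probabilities for the 2 classes
```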