What is the vanishing gradient problem?

In gradient-based learning algorithms, we use gradients to learn the weights of a neural network. Backpropagation works like a chain reaction: by the chain rule, the gradient for a layer close to the input is obtained by multiplying together the gradients of all the layers between it and the output. These gradients are then used to update the weights of the neural network.

If the gradients are small, their product becomes so small that it is close to zero. As a result, the weights of the layers closest to the input barely change, and the model is unable to learn. This problem is called the vanishing gradient problem.

Example

Here, we'll take an overly simplified example to understand this vanishing effect. Suppose we have a 20-layer neural network in which each layer has only one neuron.

20-Layer Neural Network

Here, $x$ is our 1-dimensional input to the neural network, which generates $\hat{y}$ as the predicted output. The weights are represented by $w_i$, and the $i^{th}$ layer uses some activation function $a_i$.

In the feed-forward step, we calculate the outputs of the neurons one by one. The output of the first neuron is computed from the input $x$, the weight $w_1$, and the activation function $a_1$. This output is then fed as the input to the neuron in the second layer, and so on.

As we can observe, the outputs of the neurons are chained together. If we wrote out the derivative of the output with respect to the very first weight by hand, it would be very tedious and hard to keep track of.
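The chaining can be written out explicitly. The following is only a sketch: it assumes that the neurons have no bias terms (the example does not say either way), and it introduces two names that are not part of the original example, $z_i$ for the pre-activation and $o_i$ for the output of the $i^{th}$ neuron, with $o_0 = x$.

```latex
\begin{aligned}
z_i &= w_i \, o_{i-1}, \qquad o_i = a_i(z_i), \qquad o_0 = x, \qquad \hat{y} = o_{20} \\
\frac{\partial \hat{y}}{\partial w_1}
  &= \frac{\partial o_{20}}{\partial o_{19}} \cdots \frac{\partial o_2}{\partial o_1} \cdot \frac{\partial o_1}{\partial w_1}
   = \left( \prod_{i=2}^{20} w_i \, a_i'(z_i) \right) a_1'(z_1) \, x
\end{aligned}
```

Under these assumptions, the gradient for the first weight contains one factor $w_i \, a_i'(z_i)$ for every later layer, so it shrinks geometrically whenever those factors are smaller than 1 in magnitude.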

When the gradient with respect to a weight is large, that weight sees a greater change. For the sake of understanding, let's say that the value of all the weights is the same ($0.17$), and the input value is $1$. Then the chain multiplies many factors of $0.17$ together; for instance, $0.17^{20} \approx 4 \times 10^{-16}$.

This is a very small value, and if there were $50$ layers in the neural network, this would be even smaller.
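To get a feel for the numbers, here is a minimal sketch that multiplies one factor of 0.17 per layer. It assumes that each layer contributes exactly one such factor and ignores the derivatives of the activation functions, so it illustrates the scale of the effect rather than computing an exact gradient. The helper name `product_of_factors` is made up for this illustration.

```python
def product_of_factors(factor, num_layers):
    """Multiply `num_layers` copies of `factor`, one per layer in the chain."""
    result = 1.0
    for _ in range(num_layers):
        result *= factor
    return result

# Assumption: every layer contributes the same factor, 0.17.
grad_20 = product_of_factors(0.17, 20)
grad_50 = product_of_factors(0.17, 50)

print(f"20 layers: {grad_20:.3e}")
print(f"50 layers: {grad_50:.3e}")
```

With 20 layers, the product is on the order of $10^{-16}$. With 50 layers, it drops to the order of $10^{-39}$, far too small to produce a meaningful weight update.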

Causes of vanishing gradients

The effect of a small gradient value is smaller and smaller changes in the weights, until eventually the neural network stops training altogether. But we would like to know why the gradient gets small in the first place.

Vanishing gradients are very commonly caused by the use of the sigmoid activation function. Even though sigmoid is used a lot, it is prone to the vanishing gradient problem, especially as the depth of the neural network increases. The mathematical formula of the sigmoid function is given below:

$$\sigma(x) = \frac{1}{1 + e^{-x}}$$

Here, $x$ is the input. The range of the sigmoid function $\sigma$ is $(0, 1)$, which is pretty narrow in comparison with some other activation functions. More importantly, its derivative $\sigma'(x) = \sigma(x)(1 - \sigma(x))$ never exceeds $0.25$, so every sigmoid layer shrinks the gradient that passes through it.
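The effect of sigmoid on the gradient can be checked numerically. The sketch below uses the standard sigmoid and its derivative, $\sigma(x)(1 - \sigma(x))$, and then looks at the best possible case for a 20-layer chain, where every layer sits exactly at the derivative's peak. The function and variable names are chosen for this illustration only.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x):
    """Derivative of the sigmoid: sigma(x) * (1 - sigma(x))."""
    s = sigmoid(x)
    return s * (1.0 - s)

# The derivative peaks at x = 0 and falls off quickly on both sides.
peak = sigmoid_grad(0.0)
tail = sigmoid_grad(5.0)

# Best case for a 20-layer chain: every layer sits exactly at the peak.
best_case_20_layers = peak ** 20

print(peak, tail, best_case_20_layers)
```

Even in this best case, the sigmoid derivatives alone scale the gradient by $0.25^{20}$, which is below $10^{-12}$. Away from $x = 0$, the derivative is smaller still.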

Comparison of sigmoid with tanh and ReLU

Solutions

Some common ways to counter the vanishing gradient are as follows:

  • Use residual blocks: ResNets implement residual blocks, whose skip connections give the gradient a direct path around each block. This is very effective in countering the vanishing gradient problem.

  • Use a different activation function: As shown in the figure above, some activation functions, such as ReLU, are better than sigmoid at preserving the gradient.

  • Use careful weight initialization: Normally, we randomly initialize the weights for a neural network. Some weight initialization techniques, such as Xavier initialization, choose the scale of the random weights so that the variance of the activations and gradients stays roughly constant from layer to layer.
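The first two remedies can be compared on the one-neuron-per-layer chain from the example. This is a rough sketch under several assumptions that are not in the original: there are no bias terms, all weights are set to 1.0 so that only the activation function affects the gradient, and a "residual" layer is simplified to `act(z) + previous_output`. The helper `grad_first_weight` is a hypothetical name for this illustration, not a library function.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_grad(z):
    s = sigmoid(z)
    return s * (1.0 - s)

def relu(z):
    return max(0.0, z)

def relu_grad(z):
    return 1.0 if z > 0 else 0.0

def grad_first_weight(weights, x, act, act_grad, residual=False):
    """d(y_hat)/d(w_1) for a chain of one-neuron layers without biases."""
    # Forward pass: remember each layer's input and pre-activation.
    inputs, pre_acts = [], []
    out = x
    for w in weights:
        z = w * out
        inputs.append(out)
        pre_acts.append(z)
        out = act(z) + (out if residual else 0.0)
    # Backward pass: multiply local derivatives from the last layer back to layer 2.
    grad = 1.0
    for i in range(len(weights) - 1, 0, -1):
        local = weights[i] * act_grad(pre_acts[i])
        # A skip connection adds 1 to the local derivative.
        grad *= local + (1.0 if residual else 0.0)
    # Layer 1: derivative of its activation with respect to w_1.
    return grad * act_grad(pre_acts[0]) * inputs[0]

weights = [1.0] * 20  # assumption: isolate the effect of the activation
plain_sigmoid = grad_first_weight(weights, 1.0, sigmoid, sigmoid_grad)
plain_relu = grad_first_weight(weights, 1.0, relu, relu_grad)
residual_sigmoid = grad_first_weight(weights, 1.0, sigmoid, sigmoid_grad, residual=True)

print(f"sigmoid:            {plain_sigmoid:.3e}")
print(f"ReLU:               {plain_relu:.3e}")
print(f"sigmoid + residual: {residual_sigmoid:.3e}")
```

In this toy setting, the plain sigmoid chain yields a gradient below $10^{-12}$, while ReLU passes the gradient through unchanged because its derivative is 1 for positive inputs. The residual variant keeps the gradient at a usable size even with sigmoid, since each skip connection contributes a factor of at least 1.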

Copyright ©2024 Educative, Inc. All rights reserved