Skip Connections
Learn about skip connections and the problem they solve.
If you tried to train a deep neural network back in 2014, you would very likely have run into the so-called vanishing gradient problem. In simple terms: you sit in front of the screen watching your network train, and all you see is that the training loss stops decreasing while it is still far from the desired value. You spend all night checking every line of your code for a mistake, and you find no clue.
The update rule and the vanishing gradient problem
Let’s revisit the update rule of gradient descent without momentum, where $L$ is the loss function and $\lambda$ is the learning rate:
$$w_i' = w_i + \Delta w_i$$
where
$$\Delta w_i = -\lambda \frac{\partial L}{\partial w_i}$$
You update the parameters by changing them by a small amount calculated from the gradient. For instance, suppose that for an early layer the average gradient $\partial L / \partial w_i$ is 1e-15. Given a learning rate of 1e-4 ($\lambda$ in the equation), you change the layer parameters by the product of the two quantities, which has a magnitude of 1e-19 ($\Delta w_i$). As a result, you observe no actual change in the model while training your network. This is how the vanishing gradient problem shows up in practice.
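To make the arithmetic concrete, here is a minimal Python sketch of a single update step with the numbers from the text. The starting weight value 0.5 is a hypothetical choice for illustration only, and plain Python floats (double precision) stand in for the lower-precision numbers a real framework would typically use.

```python
learning_rate = 1e-4   # lambda in the update rule
gradient = 1e-15       # average gradient of an early layer, as in the text

# Delta w = -lambda * dL/dw
delta_w = -learning_rate * gradient

w = 0.5                # hypothetical weight value, chosen for illustration
w_new = w + delta_w    # w' = w + Delta w

print("update size:", abs(delta_w))
print("weight unchanged after the update:", w_new == w)
```

The update is so small that it is lost in floating-point rounding, so the stored weight does not change at all.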
Skip connections for the win
Today, the skip connection is a standard module in many CNN architectures. By using a skip connection, we provide an alternative path for the gradient during backpropagation.
It is experimentally validated that these additional paths are often beneficial for the convergence of the model.
Skip connections, as the name suggests, skip some layers in the neural network and feed the output of one layer as input to later layers, not only to the layer that immediately follows.
As explained in the second chapter, the chain rule makes us keep multiplying the error gradient by additional terms as we go backwards. In a long chain of multiplications, if many of those factors are less than one, the resulting gradient will be very small.
Thus, the gradient becomes very small as we approach the earlier layers in a deep network. In some cases, the gradient becomes zero, meaning that we do not update the early layers at all.
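The shrinking effect of the chain rule can be sketched in a few lines of Python. The helper `backprop_scale` and the per-layer factor of 0.25 are assumptions made for illustration (0.25 is the largest value the derivative of a sigmoid can take); a real network has a different factor at every layer.

```python
def backprop_scale(per_layer_factor, num_layers):
    """Multiply the same factor once per layer, as the chain rule does."""
    scale = 1.0
    for _ in range(num_layers):
        scale *= per_layer_factor
    return scale

# How much of the gradient survives after passing back through `depth` layers
for depth in (5, 10, 25, 50):
    print(depth, backprop_scale(0.25, depth))
```

Even with this optimistic factor, the surviving gradient drops by many orders of magnitude as the depth grows.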
In general, there are two fundamental ways that one could use skip connections through different non-sequential layers:
a) Addition, as in residual architectures
b) Concatenation, as in densely connected architectures
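The difference between the two options can be sketched with plain Python lists standing in for feature vectors. The function `layer` is a hypothetical stand-in for whatever learned transformation sits between the two ends of the skip connection; it is not taken from any particular architecture.

```python
def layer(x):
    # Hypothetical stand-in for a learned transformation F(x)
    return [0.1 * v for v in x]

def skip_add(x):
    # a) Addition: element-wise x + F(x); the feature size stays the same
    return [a + b for a, b in zip(x, layer(x))]

def skip_concat(x):
    # b) Concatenation: [x, F(x)]; the feature size grows
    return x + layer(x)

features = [1.0, 2.0, 3.0]
print(len(skip_add(features)))     # same length as the input
print(len(skip_concat(features)))  # twice the length of the input
```

Addition requires the input and the layer output to have the same shape, while concatenation keeps both sets of features and passes a larger input to the next layer.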
ResNet: skip connections via addition
The core idea is to backpropagate through the identity function by using a simple element-wise addition. Along the skip path the gradient is then simply multiplied by one, so its value is maintained in the earlier layers. This is the main idea behind Residual Networks (ResNets): they stack these residual blocks together and use the identity function to preserve the gradient.
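As a rough illustration of why the identity path helps, consider a toy scalar model in which every layer has the same small derivative. For a residual block y = x + F(x), the derivative with respect to x is 1 + F'(x), so each block multiplies the gradient by roughly one instead of by a small number. The function names and the value 0.01 below are assumptions for this sketch, not part of ResNet itself.

```python
def plain_gradient(layer_derivative, num_layers):
    """Gradient through a plain stack: product of F'(x) over all layers."""
    grad = 1.0
    for _ in range(num_layers):
        grad *= layer_derivative
    return grad

def residual_gradient(layer_derivative, num_layers):
    """Gradient through residual blocks y = x + F(x): product of 1 + F'(x)."""
    grad = 1.0
    for _ in range(num_layers):
        grad *= 1.0 + layer_derivative
    return grad

small_derivative = 0.01  # hypothetical per-layer derivative F'(x)
print("plain stack:   ", plain_gradient(small_derivative, 20))
print("residual stack:", residual_gradient(small_derivative, 20))
```

In this toy setting the plain stack drives the gradient toward zero, while the residual stack keeps it close to one.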