Pre-activation
Learn why batch normalization and pre-activation are important in ResNet.
Chapter Goals:
- Learn about internal covariate shift and batch normalization
- Understand the difference between pre-activation and post-activation
- Implement the pre-activation function
A. Internal covariate shift
When training models with many weight layers, e.g. ResNet, a problem known as internal covariate shift can occur. To understand this problem, we first define a related problem known as covariate shift.
A covariate shift occurs when the input data's distribution changes and the model cannot handle the change properly. For example, suppose a model is trained to classify dog breeds using a training set containing only images of brown dogs. If we then test the model on images of yellow dogs, its performance may not be as good as expected.
In this case, the model's original input distribution was limited to just brown dogs, and changing the input distribution to a different color of dogs introduced covariate shift.
An internal covariate shift is essentially just a covariate shift that happens between layers of a model. Since the input of one layer is the output of the previous layer, the input distribution for a layer is the same as the output distribution of the previous layer.
Because the output distribution of a layer depends on its weights, and the weights of a model are constantly being updated, each layer's output distribution will constantly change (though by incremental amounts). In a model with only a few layers, these incremental changes in layer distributions have little impact. However, in models with many layers, the incremental changes eventually add up and lead to internal covariate shift at deeper layers.
B. Batch normalization
The solution to internal covariate shift is batch normalization. Since internal covariate shift is caused by distribution changes between layers, we can remedy it by enforcing a fixed distribution on the inputs to each layer.
Batch normalization accomplishes this by subtracting the mean from the inputs and dividing by the standard deviation (i.e. the square root of the variance). This results in a standardized distribution, with a mean of 0 and a variance of 1.
One thing to note is that batch normalization is applied across a specific dimension of the data. For CNNs, we apply it across the channels dimension, meaning we standardize the values for each channel of the input data. So the mean and variance are actually vectors of num_channels values.
Below we show standardization of an input with 2 channels and a batch size of 3.
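As a concrete sketch of this computation, the snippet below standardizes a randomly generated batch of 3 inputs with 2 channels each. The batch size and channel count match the example above; the 2×2 spatial size and the random values are assumptions made purely for illustration.

```python
import numpy as np

# Illustrative input: a batch of 3 images, each 2x2 with 2 channels (NHWC).
inputs = np.random.randn(3, 2, 2, 2).astype(np.float32)

# Per-channel mean and variance, computed over the batch, height, and width
# dimensions. Each result is a vector of num_channels (here, 2) values.
mean = inputs.mean(axis=(0, 1, 2))        # shape: (2,)
variance = inputs.var(axis=(0, 1, 2))     # shape: (2,)

# Standardize each channel: subtract the mean and divide by the standard
# deviation. A small epsilon keeps the division numerically stable.
epsilon = 1e-5
standardized = (inputs - mean) / np.sqrt(variance + epsilon)

# Each channel of the standardized output now has mean ~0 and variance ~1.
print(standardized.mean(axis=(0, 1, 2)))  # approximately [0. 0.]
print(standardized.var(axis=(0, 1, 2)))   # approximately [1. 1.]
```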
For some layers we might not want a standardized distribution of the inputs. Maybe we want a distribution with a different mean or variance. Luckily, batch normalization comes with two trainable variables, a per-channel scale (commonly called gamma) and a per-channel shift (commonly called beta), which let the model learn whatever output mean and variance works best for each layer.
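As a rough sketch of how this looks in practice, assuming a TensorFlow/Keras setup (the framework choice here is an assumption), the BatchNormalization layer creates and learns these two variables automatically:

```python
import tensorflow as tf

# Hedged sketch, assuming TensorFlow/Keras: the BatchNormalization layer
# creates the two trainable variables itself (gamma, the scale, and beta,
# the shift), one value of each per channel.
bn = tf.keras.layers.BatchNormalization(axis=-1)  # normalize across channels (NHWC)

inputs = tf.random.normal([3, 2, 2, 2])  # batch of 3, 2x2 spatial size, 2 channels
outputs = bn(inputs, training=True)      # standardize, then scale by gamma and shift by beta

# gamma and beta are trainable, so gradient descent can learn a non-standard
# mean and variance for this layer's outputs when that helps.
print([v.name for v in bn.trainable_variables])
```

During training the layer also tracks moving averages of the batch mean and variance, which it uses in place of the batch statistics at inference time.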