Auto-differentiation
This lesson will cover auto-differentiation, the core component of deep learning libraries.
Background
Ever since the seminal work of Rumelhart, Hinton, and Williams in 1986, the majority of machine learning optimization methods have relied on derivatives, so there is a pressing need to calculate them efficiently.
Manual calculations
Most early machine learning researchers (for example, Bottou, 1998, for stochastic gradient descent) had to derive analytical derivatives by hand, a slow and laborious process that is prone to error.
Using a computer program
Programming-based solutions are less laborious, but calculating derivatives in a program can also be tricky. These solutions fall into three paradigms:
- Symbolic differentiation
- Numeric differentiation
- Automatic differentiation
The first two methods have several drawbacks:
- Higher-order derivatives are hard to compute: symbolic differentiation produces long, complex expressions, while numeric differentiation accumulates round-off errors that make the results less accurate.
- Numeric differentiation relies on discretization, which results in a loss of accuracy.
- Symbolic differentiation can lead to inefficient code.
- Both are slow at calculating partial derivatives with respect to many inputs, which is a key requirement of gradient-based optimization algorithms.
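The accuracy problem of numeric differentiation can be seen in a small experiment. The sketch below is only an illustration: the helper name `forward_difference`, the test function (sine), and the step sizes are assumptions chosen for this example. It approximates a derivative with a forward finite difference and records the absolute error for several step sizes.

```python
import math


def forward_difference(f, x, h):
    """Approximate f'(x) with the forward finite difference (f(x + h) - f(x)) / h."""
    return (f(x + h) - f(x)) / h


x = 1.0
exact = math.cos(x)  # analytical derivative of sin at x

# Absolute error of the approximation for a range of step sizes (illustrative values).
errors = {
    h: abs(forward_difference(math.sin, x, h) - exact)
    for h in (1e-1, 1e-4, 1e-8, 1e-12)
}

for h, err in errors.items():
    print(f"h = {h:.0e}   abs error = {err:.3e}")
```

A large step suffers from discretization (truncation) error, while a very small step suffers from floating-point round-off, so the error first shrinks and then grows again as the step decreases. No single step size removes both sources of error.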
Automatic differentiation
Automatic differentiation (also known as autodiff) addresses all the issues above and is a key feature of modern ML/DL libraries.
Chain rule
Autodiff centers on the chain rule, the fundamental rule in calculus for calculating the derivatives of composite functions.
For example,
and,
Differentiating y with respect to w (i.e., dy/dw) is not directly possible. Instead, it is calculated indirectly using the chain rule:
and,
There are two main ways to evaluate the product of derivatives that the chain rule produces.
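To make the idea of different evaluation orders concrete, consider a hypothetical three-stage composition. The functions and the intermediate symbols u and v are assumptions chosen purely for illustration and are not taken from the lesson. The chain rule gives a product of three factors, and the two approaches differ only in how that product is grouped: starting from the input side corresponds to forward accumulation, and starting from the output side corresponds to reverse accumulation.

```latex
\begin{aligned}
y &= \sin(v), \qquad v = u + 1, \qquad u = w^2 \\
\frac{dy}{dw} &= \frac{dy}{dv}\,\frac{dv}{du}\,\frac{du}{dw}
              = \cos(v)\cdot 1\cdot 2w
              = 2w\cos(w^2 + 1) \\
\text{input side first:}\quad \frac{dy}{dw} &= \frac{dy}{dv}\left(\frac{dv}{du}\,\frac{du}{dw}\right) \\
\text{output side first:}\quad \frac{dy}{dw} &= \left(\frac{dy}{dv}\,\frac{dv}{du}\right)\frac{du}{dw}
\end{aligned}
```

For scalar functions both groupings cost the same. The difference matters when the factors are Jacobian matrices, because the order of multiplication then determines the size of the intermediate results.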
Forward accumulation
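In forward accumulation, we fix the independent variable with respect to which we differentiate and compute the derivative of each sub-expression recursively, moving from the inputs toward the output.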
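One common way to implement forward accumulation is with dual numbers, where every value carries its derivative with respect to the chosen input. The following is a minimal sketch under that assumption, not the implementation used by any particular library. The class `Dual`, the helper `_lift`, and the example function `f` are hypothetical names introduced for this illustration.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """A value paired with its derivative with respect to one chosen input."""
    value: float
    deriv: float

    def __add__(self, other):
        other = _lift(other)
        # Sum rule: (u + v)' = u' + v'
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule: (u * v)' = u' * v + u * v'
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__


def _lift(x):
    """Treat plain numbers as constants, whose derivative is 0."""
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)


def sin(x):
    # Chain rule: d/dw sin(u) = cos(u) * du/dw
    return Dual(math.sin(x.value), math.cos(x.value) * x.deriv)


def f(w, b):
    """Hypothetical example function: sin(w^2 + b)."""
    return sin(w * w + b)


# Fix w as the independent variable by seeding it with derivative 1.
# b is seeded with 0 because we are not differentiating with respect to it.
w = Dual(2.0, 1.0)
b = Dual(1.0, 0.0)
y = f(w, b)

print(y.value, y.deriv)  # function value and dy/dw, computed in a single pass
```

Each pass yields the derivative with respect to one input only. Obtaining the derivative with respect to `b` as well would require a second pass with the seeds swapped, which is why this mode becomes expensive when there are many inputs.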
Reverse accumulation
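In reverse accumulation, we fix the dependent variable (the output) and compute its derivative with respect to each sub-expression recursively, moving from the output back toward the inputs. This is the mode behind backpropagation in deep learning, and it is what major frameworks such as PyTorch and TensorFlow use.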
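Reverse accumulation can be sketched by recording the computation as a graph and then walking it backward from the output. The code below is a simplified illustration under that assumption and does not reproduce how PyTorch or TensorFlow work internally. The class `Var` and the functions `sin` and `backward` are hypothetical names introduced for this example.

```python
import math


class Var:
    """A node in the computation graph that records how it was produced."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # pairs of (parent_node, local_derivative)
        self.grad = 0.0         # will hold d(output)/d(this node)

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))


def sin(x):
    return Var(math.sin(x.value), ((x, math.cos(x.value)),))


def backward(output):
    """Propagate d(output)/d(node) from the output back to every input."""
    order, seen = [], set()

    def visit(node):
        # Post-order traversal: a node is appended after all of its parents.
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(output)
    output.grad = 1.0  # seed: d(output)/d(output) = 1
    # Reversed post-order processes each node before the nodes it was built from.
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += node.grad * local  # chain rule, accumulated over paths


w, b = Var(2.0), Var(1.0)
y = sin(w * w + b)
backward(y)

print(w.grad, b.grad)  # derivatives of y with respect to both inputs
```

A single backward pass produces the derivatives with respect to every input at once. This is what makes the mode attractive when a model has many parameters and a single scalar loss.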
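JAX supports both types of accumulation (for example, jax.jacfwd for forward mode and jax.grad or jax.jacrev for reverse mode). The choice between them usually depends on the relative number of inputs and outputs: forward accumulation is cheaper when there are few inputs, while reverse accumulation is cheaper when there are many inputs (features or parameters) and few outputs. Because a deep learning model typically has many parameters and a single scalar loss, reverse accumulation is generally the default method in deep learning.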