From the Chain Rule to Backpropagation
Learn how to apply the chain rule in backpropagation.
Backpropagation is an application of the chain rule, one of the fundamental rules of calculus. The chain rule calculates the gradient of any node y with respect to any other node x: we multiply the local gradients of all the operations on the way back from y to x.
Let’s understand how the chain rule works on a couple of network-like structures: first a simple one, then a more complicated one.
The chain rule on a simple network
Let’s look at this simple network-like structure:
This is not a neural network, because it does not have weights. Let’s borrow a term from computer science and call it a computational graph. This graph has an input x, followed by two operations: multiply by two and square. The output of the multiplication is called m, and the output of the entire graph is called y.
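To make the graph concrete, here is a minimal sketch of it in plain Python (the function name `forward` is our own; the lesson itself only describes the operations):

```python
# The computational graph as plain Python: multiply by two, then square.
def forward(x):
    m = 2 * x   # first operation: output m
    y = m ** 2  # second operation: output y
    return y

print(forward(3))  # 36, because y = (2 * 3) ** 2
```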
Now let’s say that we want to calculate ∂y/∂x, the gradient of y with respect to x. Intuitively, that gradient represents the impact of x on y. Whenever x changes, y also changes, and the gradient measures the amount of that change. (If you find gradients confusing, review the Gradient Descent lesson.)
For such a small graph, we could calculate ∂y/∂x in a single step, by taking the derivative of y with respect to x. However, as we mentioned earlier, that derivation would become impractical for very large graphs. Instead, let’s calculate the gradient using the chain rule, which works for graphs of any size.
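For the record, the single-step derivative is easy to compute here: y = (2x)² = 4x², so ∂y/∂x = 8x. Keep that result in mind, because the chain rule should reproduce it.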
Here is how the chain rule works. To calculate ∂y/∂x:
- Walk the graph back from y to x.
- For each operation along the way, calculate its local gradient—the derivative of the operation’s output with respect to its input.
- Multiply all the local gradients together.
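In equation form, those steps boil down to a product of local gradients. For our graph, which passes through the intermediate node m, they read:

∂y/∂x = ∂y/∂m · ∂m/∂x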
Let’s see how that process works in practice. In our case, the path back from y to x involves two operations:
- A square.
- A multiplication by two.
Let’s compile the local gradients of those two operations:
- The square takes m and outputs y = m², so its local gradient is ∂y/∂m = 2m.
- The multiplication by two takes x and outputs m = 2x, so its local gradient is ∂m/∂x = 2.
How do we know that ∂y/∂m is 2m, and ∂m/∂x is 2? Well, even though we use the chain rule, we must still compute the local gradients in an old-fashioned way, by taking derivatives by hand. However, don’t worry if you do not know how to take derivatives. We can always use libraries to do that. For now, we just have to understand the process.
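For example, here is a minimal sketch that uses PyTorch’s autograd to take the derivatives for us (PyTorch is our choice for illustration; the lesson does not prescribe a specific library):

```python
import torch

# Rebuild the computational graph with a tensor that tracks gradients.
x = torch.tensor(3.0, requires_grad=True)
m = 2 * x      # multiply by two
y = m ** 2     # square
y.backward()   # walks the graph back from y, applying the chain rule
print(x.grad)  # tensor(24.), which matches 8x at x = 3
```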
Now that we have the local gradients, we can multiply them to get ∂y/∂x:
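∂y/∂x = ∂y/∂m · ∂m/∂x = 2m · 2 = 4m

Substituting m = 2x, we get ∂y/∂x = 8x, the same result as differentiating y = 4x² directly.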
...