Multivariate Calculus
This lesson will introduce multivariate calculus in JAX.
We can build on our linear algebra refresher and use automatic differentiation for vectors and matrices as well. This is instrumental in neural networks, where we apply gradient-based optimization to whole matrices.
Gradient
If we have a multivariate function (a function of more than one variable), $f(x_1, x_2, \ldots, x_n)$, we calculate its derivative by taking the partial derivative with respect to every (input/independent) variable.
For example, if we have a function of several variables, its derivative, represented by $\nabla f$ and called the gradient, is calculated by taking the partial derivative with respect to each of them.
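As a minimal illustration (this specific function is an assumption, chosen only so that the first and third components of its gradient are the constants 4 and -1 referenced in the note below), take:

$$f(x_1, x_2, x_3) = 4x_1 + x_2^2 - x_3$$

whose gradient is

$$\nabla f = \left[\frac{\partial f}{\partial x_1},\; \frac{\partial f}{\partial x_2},\; \frac{\partial f}{\partial x_3}\right] = \left[\,4,\; 2x_2,\; -1\,\right]$$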
Generally, we can define this as:
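For an $n$-variable function, the gradient collects all of the partial derivatives into a single vector:

$$\nabla f(x_1, x_2, \ldots, x_n) = \left[\frac{\partial f}{\partial x_1},\; \frac{\partial f}{\partial x_2},\; \ldots,\; \frac{\partial f}{\partial x_n}\right]$$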
The above example can be calculated using grad(), as shown below.
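The following is a minimal sketch using the hypothetical function assumed above (not the lesson's original code). grad() transforms f into a new function that returns the gradient of f with respect to its array-valued argument.

```python
import jax.numpy as jnp
from jax import grad

# Hypothetical example function (an assumption, not from the original lesson):
# f(x1, x2, x3) = 4*x1 + x2**2 - x3
def f(x):
    return 4 * x[0] + x[1] ** 2 - x[2]

# grad() returns a function that computes the gradient of f
# with respect to its (vector-valued) argument x.
grad_f = grad(f)

x = jnp.array([1.0, 2.0, 3.0])
print(f(x))       # 5.0
print(grad_f(x))  # [ 4.  4. -1.]  ->  (4, 2*x2, -1) evaluated at x2 = 2.0
```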
Note: Since the first and third components of the gradient in the example are the constants 4 and -1, we can change the first and third entries of the input vector and confirm that those components of the output stay the same.