...


Multivariate Calculus

This lesson will introduce multivariate calculus in JAX.

We can build on our linear algebra refresher and use automatic differentiation on vectors and matrices as well. This is instrumental in neural networks, where we apply gradient-based optimization to whole matrices of parameters.
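As a quick preview, `jax.grad` works directly on functions of whole arrays. The sketch below uses our own toy function $f(\mathbf{w}) = \mathbf{w} \cdot \mathbf{w}$ (an assumed example, not one from this lesson), whose gradient is $2\mathbf{w}$:

```python
import jax
import jax.numpy as jnp

# A scalar-valued function of a whole vector: f(w) = w . w
def f(w):
    return jnp.dot(w, w)

# jax.grad turns f into a function that returns the gradient,
# an array of the same shape as w (analytically, 2w).
grad_f = jax.grad(f)

w = jnp.array([1.0, 2.0, 3.0])
print(grad_f(w))  # [2. 4. 6.]
```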

Gradient

If we have a multivariate function (a function of more than one variable), $f: \mathbb{R}^n \to \mathbb{R}$, we calculate its derivative by taking the partial derivative with respect to every input (independent) variable.

For example, if we have:

$f(x,y,z) = 4x + 3y^2 - z$

its derivative, represented by $\nabla f$, will be calculated as:

$$\nabla f(x,y,z) = \begin{bmatrix} \frac{\partial f(x,y,z)}{\partial x} \\[6pt] \frac{\partial f(x,y,z)}{\partial y} \\[6pt] \frac{\partial f(x,y,z)}{\partial z} \end{bmatrix} = \begin{bmatrix} 4 \\ 6y \\ -1 \end{bmatrix}$$
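We can verify this gradient with `jax.grad`. Below is a minimal sketch (the evaluation point $x=1, y=2, z=3$ is an arbitrary choice for illustration); the `argnums` parameter tells `jax.grad` which positional arguments to differentiate with respect to:

```python
import jax

def f(x, y, z):
    return 4 * x + 3 * y**2 - z

# Differentiate f with respect to all three arguments at once;
# jax.grad returns a function that evaluates the gradient.
grad_f = jax.grad(f, argnums=(0, 1, 2))

# Evaluate the gradient at an arbitrary point.
df_dx, df_dy, df_dz = grad_f(1.0, 2.0, 3.0)
print(df_dx, df_dy, df_dz)  # 4.0 12.0 -1.0, matching [4, 6y, -1] at y = 2
```

Note that each partial derivative matches the analytical result: $\partial f/\partial x = 4$ and $\partial f/\partial z = -1$ are constants, while $\partial f/\partial y = 6y$ depends on the evaluation point.

...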