The Hessian matrix is important in machine learning and data science because it is used to optimize functions. It is used in neural networks and other models.
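If we take the second-order partial derivatives of a scalar-valued function $f$, the resulting matrix is called the Hessian matrix. Since the order of differentiation can be swapped when the second partial derivatives are continuous, the Hessian matrix is symmetric.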
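For reference, the definition can be written out explicitly. The sketch below assumes a scalar-valued function $f$ of $n$ variables $x_1, \dots, x_n$; these symbols are chosen here for illustration:

```latex
H_f(x) =
\begin{bmatrix}
\dfrac{\partial^2 f}{\partial x_1^2} & \cdots & \dfrac{\partial^2 f}{\partial x_1 \, \partial x_n} \\
\vdots & \ddots & \vdots \\
\dfrac{\partial^2 f}{\partial x_n \, \partial x_1} & \cdots & \dfrac{\partial^2 f}{\partial x_n^2}
\end{bmatrix},
\qquad
\left(H_f\right)_{ij} = \frac{\partial^2 f}{\partial x_i \, \partial x_j}
```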
A Hessian matrix has several uses, including the following:
We can evaluate a Hessian in several numerical computing libraries. The following are the respective functions in some of them:
PyTorch: torch.autograd.functional.hessian() is used to calculate a Hessian matrix in PyTorch. To learn more, refer to the official documentation.
JAX: jax.hessian() is used to calculate a Hessian matrix in JAX. The official documentation can help in understanding its internal implementation.
Limiting ourselves to JAX here, we can calculate a Hessian matrix directly using jax.hessian().
As an example, take the function $F(x_0, x_1) = 3x_0^3 - 6x_1^2 + 3$:
    import jax
    import jax.numpy as jnp

    def F(x):
        return 3*x[0]*x[0]*x[0] - 6*x[1]*x[1] + 3

    hessian_x = jax.hessian(F)
    print(type(hessian_x))  # a function that we will use later
The code above imports the required libraries and passes F to jax.hessian(), which returns a function that computes the Hessian matrix of F.
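As a hand check on what jax.hessian() should return, the second derivatives of the example function can be worked out directly. Writing $x_0$ and $x_1$ for x[0] and x[1] in the code (a notation assumed here), we get:

```latex
F(x) = 3x_0^3 - 6x_1^2 + 3
\\[4pt]
\frac{\partial F}{\partial x_0} = 9x_0^2,
\qquad
\frac{\partial F}{\partial x_1} = -12x_1
\\[4pt]
H_F(x) =
\begin{bmatrix}
\dfrac{\partial^2 F}{\partial x_0^2} & \dfrac{\partial^2 F}{\partial x_0 \, \partial x_1} \\[8pt]
\dfrac{\partial^2 F}{\partial x_1 \, \partial x_0} & \dfrac{\partial^2 F}{\partial x_1^2}
\end{bmatrix}
=
\begin{bmatrix}
18x_0 & 0 \\
0 & -12
\end{bmatrix}
```

Only the top-left entry depends on the input, so the Hessian of this function changes with $x_0$ alone.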
Having created hessian_x, we are now in a position to evaluate it at any point.
Like most linear algebra functions, hessian() works on array inputs, which means that we need to convert a pair such as (1.2, 3.4) into a vector, for example jnp.array([1.2, 3.4]), before passing it as an input.
    X = jnp.array([1.2, 3.4])
    print("Hessian evaluation at (1.2,3.4)")
    print(hessian_x(X))

    Y = jnp.array([-1.0, 1.0])
    print("Hessian evaluation at (-1,1)")
    print(hessian_x(Y))
We can evaluate the Hessian at any point, like X and Y above.
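One way to sanity-check these evaluations is to compare them against the closed form [[18*x0, 0], [0, -12]], obtained by differentiating the example function twice by hand, and against the composition jax.jacfwd(jax.jacrev(F)), which is how the JAX documentation describes jax.hessian(). The helper name analytic_hessian below is hypothetical and exists only for this check:

```python
import jax
import jax.numpy as jnp

def F(x):
    # Same example function as above: 3*x0^3 - 6*x1^2 + 3
    return 3 * x[0] ** 3 - 6 * x[1] ** 2 + 3

def analytic_hessian(x):
    # Hypothetical helper: Hessian of F worked out by hand
    return jnp.array([[18.0 * x[0], 0.0],
                      [0.0, -12.0]])

X = jnp.array([1.2, 3.4])

H_auto = jax.hessian(F)(X)             # automatic differentiation
H_comp = jax.jacfwd(jax.jacrev(F))(X)  # forward-over-reverse composition
H_hand = analytic_hessian(X)           # closed form

print(jnp.allclose(H_auto, H_hand))
print(jnp.allclose(H_auto, H_comp))
```

If the three results agree, the automatic Hessian matches both the manual derivation and the Jacobian-of-Jacobian view of second derivatives.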
As mentioned at the start, we can use a Hessian to test the critical points in a pretty simple way.
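If the Hessian at a critical point is as follows:
Positive definite (all eigenvalues positive): the point is a local minimum.
Negative definite (all eigenvalues negative): the point is a local maximum.
Indefinite (eigenvalues of both signs): the point is a saddle point.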
    from jax.numpy import linalg  # used for checking matrix definiteness

    def TestCriticalPoint(hess, x):
        # The Hessian is symmetric, so eigvalsh returns real eigenvalues
        eigenvalues = linalg.eigvalsh(hess(x))
        if jnp.all(eigenvalues > 0):
            print("Point is a local minimum")
        elif jnp.all(eigenvalues < 0):
            print("Point is a local maximum")
        else:
            print("It's a saddle point")

    X = jnp.array([1.2, 3.4])
    Y = jnp.array([-1.0, 1.0])
    print("----Testing for X-----")
    TestCriticalPoint(hessian_x, X)
    print("----Testing for Y-----")
    TestCriticalPoint(hessian_x, Y)
The code above takes a couple of arrays and checks whether the given points (represented by the arrays) are minima, maxima, or saddle points.
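The second-derivative test is only meaningful at critical points, where the gradient is zero. As a hedged sketch, the gradient can be checked with jax.grad() before classifying a point. The function name is_critical and the tolerance 1e-6 are assumptions made for this illustration:

```python
import jax
import jax.numpy as jnp

def F(x):
    # Same example function as above: 3*x0^3 - 6*x1^2 + 3
    return 3 * x[0] ** 3 - 6 * x[1] ** 2 + 3

grad_F = jax.grad(F)  # gradient function of F

def is_critical(x, tol=1e-6):
    # Hypothetical helper: True when every gradient component is near zero
    return bool(jnp.all(jnp.abs(grad_F(x)) < tol))

origin = jnp.array([0.0, 0.0])
X = jnp.array([1.2, 3.4])

print(is_critical(origin))  # gradient is (0, 0) at the origin
print(is_critical(X))       # gradient is nonzero at (1.2, 3.4)
```

For this particular function, the gradient (9*x0^2, -12*x1) vanishes only at the origin, so points such as (1.2, 3.4) illustrate the eigenvalue check rather than being true critical points.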
Note: The Hessian matrix's use in second-order methods is less frequent due to its memory and computational requirements.
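To illustrate what such a method does with the Hessian, here is a minimal sketch of a single Newton step, x_new = x - H^{-1} g, where g is the gradient and H is the Hessian. The function quad and the helper newton_step are hypothetical names introduced for this example, not part of JAX:

```python
import jax
import jax.numpy as jnp

def quad(x):
    # Hypothetical convex function with its minimum at (1, -3)
    return (x[0] - 1.0) ** 2 + 2.0 * (x[1] + 3.0) ** 2

def newton_step(f, x):
    g = jax.grad(f)(x)     # gradient, shape (n,)
    H = jax.hessian(f)(x)  # Hessian, shape (n, n), so memory grows as n^2
    # Solve H @ step = g instead of forming the inverse explicitly
    return x - jnp.linalg.solve(H, g)

x_start = jnp.array([5.0, 5.0])
x_next = newton_step(quad, x_start)
print(x_next)
```

Because quad is quadratic, one step lands on its minimum. Storing the n-by-n Hessian and solving the linear system is the part that becomes expensive as the number of parameters grows.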