What is a Hessian matrix?

The Hessian matrix is important in machine learning and data science because it captures the second-order (curvature) information used when optimizing functions. It appears in the training and analysis of neural networks and other models.

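If we take all the second-order partial derivatives of a function $f:\mathbb{R}^n\to \mathbb{R}$ and arrange them in a matrix, the resultant matrix is called a Hessian matrix. Since a mixed partial derivative doesn't depend on the order of differentiation (provided the second derivatives are continuous), Hessian matrices are symmetric.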

$$
H_f= \begin{bmatrix}
\dfrac{\partial^2 f}{\partial x_1^2} & \dfrac{\partial^2 f}{\partial x_1\,\partial x_2} & \cdots & \dfrac{\partial^2 f}{\partial x_1\,\partial x_n} \\[2.2ex]
\dfrac{\partial^2 f}{\partial x_2\,\partial x_1} & \dfrac{\partial^2 f}{\partial x_2^2} & \cdots & \dfrac{\partial^2 f}{\partial x_2\,\partial x_n} \\[2.2ex]
\vdots & \vdots & \ddots & \vdots \\[2.2ex]
\dfrac{\partial^2 f}{\partial x_n\,\partial x_1} & \dfrac{\partial^2 f}{\partial x_n\,\partial x_2} & \cdots & \dfrac{\partial^2 f}{\partial x_n^2}
\end{bmatrix}
$$

A Hessian matrix has several uses, including the following:

  • Second-order optimization
  • Testing for critical points
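
To illustrate the first use, here is a minimal sketch of a single Newton step, which uses the Hessian to rescale the gradient. The function g and the helper newton_step are hypothetical names chosen for this example and are not part of any library; only jax.grad, jax.hessian, and jnp.linalg.solve are JAX calls.

```python
import jax
import jax.numpy as jnp

# Hypothetical convex test function (an assumption for this sketch):
# its minimum is at (1, -2).
def g(x):
    return (x[0] - 1.0) ** 2 + 2.0 * (x[1] + 2.0) ** 2

def newton_step(f, x):
    """One Newton update: x - H^{-1} * gradient, computed with a linear solve."""
    grad = jax.grad(f)(x)
    hess = jax.hessian(f)(x)
    # Solve H * step = grad instead of forming the inverse explicitly
    step = jnp.linalg.solve(hess, grad)
    return x - step

x0 = jnp.array([5.0, 5.0])
x1 = newton_step(g, x0)
print(x1)  # for a quadratic function, one step reaches the minimizer
```

Because g is quadratic, its Hessian is constant and a single step is enough. For general functions, the step would be repeated and usually combined with a step-size rule.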

Code

We can evaluate a Hessian in several numerical computing libraries. The relevant functions in two of them are as follows:

  • PyTorch: torch.autograd.functional.hessian() is used to calculate a Hessian matrix in PyTorch. To learn more, refer to the official documentation.

  • JAX: jax.hessian() is used to calculate a Hessian matrix. The official documentation describes how it is implemented.
    Limiting ourselves to JAX here, we can calculate a Hessian matrix directly using jax.hessian().

    As an example, take the function $f(x,y) = 3x^3-6y^2+3$:

import jax
import jax.numpy as jnp

def F(x):
    # f(x, y) = 3x^3 - 6y^2 + 3, with x[0] = x and x[1] = y
    return 3*x[0]*x[0]*x[0] - 6*x[1]*x[1] + 3

hessian_x = jax.hessian(F)
print(type(hessian_x))  # a function that we will use later

The code above imports JAX and passes F to jax.hessian(), which returns a function that computes the Hessian matrix of F.

Having created hessian_x, we are now in a position to evaluate it at any point.

The function returned by jax.hessian() differentiates with respect to its array argument, which means that we need to convert the $(x,y)$ pair into a vector

$$
X = \begin{bmatrix} x \\ y \end{bmatrix}
$$

before passing it as an input.

X=jnp.array([1.2,3.4])
print("Hessian evaluation at (1.2,3.4)")
print(hessian_x(X))
Y=jnp.array([-1.0,1.0])
print("Hessian evaluation at (-1,1)")
print(hessian_x(Y))

We can evaluate the Hessian at any point like X and Y above.
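
As a check on the printed values, the Hessian of this particular function can be worked out by hand. The derivation below uses only the definition of the Hessian and the example function $f(x,y) = 3x^3-6y^2+3$; the printed JAX values should agree with it up to floating-point rounding.

```latex
\frac{\partial f}{\partial x} = 9x^2, \qquad
\frac{\partial f}{\partial y} = -12y

H_f(x,y) =
\begin{bmatrix}
\dfrac{\partial^2 f}{\partial x^2} & \dfrac{\partial^2 f}{\partial x\,\partial y} \\[2.2ex]
\dfrac{\partial^2 f}{\partial y\,\partial x} & \dfrac{\partial^2 f}{\partial y^2}
\end{bmatrix}
=
\begin{bmatrix}
18x & 0 \\
0 & -12
\end{bmatrix}

H_f(1.2,\,3.4) =
\begin{bmatrix}
21.6 & 0 \\
0 & -12
\end{bmatrix},
\qquad
H_f(-1,\,1) =
\begin{bmatrix}
-18 & 0 \\
0 & -12
\end{bmatrix}
```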

Test for critical points

As mentioned at the start, we can use a Hessian to classify a critical point, that is, a point where the gradient of the function is zero, in a pretty simple way.

If the Hessian at a critical point is as follows:

  • Positive definite (all eigenvalues are positive): The point is a local minimum.
  • Negative definite (all eigenvalues are negative): The point is a local maximum.
  • Indefinite (some eigenvalues are positive and others are negative): The point is a saddle point.
  • Only semidefinite (some eigenvalues are zero and the rest share one sign): The test is inconclusive.

from jax.numpy import linalg  # used for checking matrix definiteness

def TestCriticalPoint(hess, x):
    # The Hessian is symmetric, so eigvalsh returns real eigenvalues
    eigenvalues = linalg.eigvalsh(hess(x))
    if jnp.all(eigenvalues > 0):
        print("Point is a local minimum")
    elif jnp.all(eigenvalues < 0):
        print("Point is a local maximum")
    elif jnp.any(eigenvalues > 0) and jnp.any(eigenvalues < 0):
        print("It's a saddle point")
    else:
        print("The test is inconclusive")

X = jnp.array([1.2, 3.4])
Y = jnp.array([-1.0, 1.0])
print("----Testing for X-----")
TestCriticalPoint(hessian_x, X)
print("----Testing for Y-----")
TestCriticalPoint(hessian_x, Y)

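The code above takes two arrays and reports, for the point each array represents, whether the eigenvalues of the Hessian there indicate a minimum, a maximum, or a saddle point. The classification is only meaningful at critical points, where the gradient is zero. For the example function, the gradient $(9x^2, -12y)$ vanishes only at the origin, so X and Y serve here only to show the mechanics of the eigenvalue check.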
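
The eigenvalue check can be combined with a gradient check so that only genuine critical points are classified. The following is a minimal sketch under stated assumptions: classify_point, its tolerance tol, and the test functions bowl and saddle are hypothetical names introduced for illustration, and the tolerance value is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

def classify_point(f, x, tol=1e-6):
    """Hypothetical helper: second-derivative test that first checks the gradient."""
    grad = jax.grad(f)(x)
    if jnp.linalg.norm(grad) > tol:
        return "not a critical point"
    # The Hessian is symmetric, so eigvalsh gives real eigenvalues
    eigenvalues = jnp.linalg.eigvalsh(jax.hessian(f)(x))
    if jnp.all(eigenvalues > tol):
        return "local minimum"
    if jnp.all(eigenvalues < -tol):
        return "local maximum"
    if jnp.any(eigenvalues > tol) and jnp.any(eigenvalues < -tol):
        return "saddle point"
    return "inconclusive"  # at least one eigenvalue is (numerically) zero

def F(x):  # the example function from this article
    return 3 * x[0] ** 3 - 6 * x[1] ** 2 + 3

def bowl(x):  # assumed test function with a minimum at the origin
    return x[0] ** 2 + x[1] ** 2

def saddle(x):  # assumed test function with a saddle at the origin
    return x[0] ** 2 - x[1] ** 2

origin = jnp.array([0.0, 0.0])
print(classify_point(F, jnp.array([1.2, 3.4])))  # gradient is nonzero here
print(classify_point(F, origin))       # one zero eigenvalue
print(classify_point(bowl, origin))
print(classify_point(saddle, origin))
```

Returning a label instead of printing makes the helper easy to test. At the origin, the example function has Hessian eigenvalues 0 and -12, so the test cannot decide there.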

Note: Second-order methods that use the full Hessian matrix are less common in practice because, for a function of n variables, the Hessian has n × n entries, which makes it expensive to compute and store.
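
One common way to reduce this cost, sketched here as an aside, is to compute the product of the Hessian with a vector without ever building the matrix. The helper name hvp is an assumption for this example; it composes the JAX calls jax.grad and jax.jvp.

```python
import jax
import jax.numpy as jnp

def F(x):  # the example function from this article
    return 3 * x[0] ** 3 - 6 * x[1] ** 2 + 3

def hvp(f, x, v):
    """Hessian-vector product H(x) @ v without materializing H."""
    # Forward-mode derivative of the gradient function along direction v
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([1.2, 3.4])
v = jnp.array([1.0, 2.0])
matrix_free = hvp(F, x, v)
full_matrix = jax.hessian(F)(x) @ v  # reference value from the full Hessian
print(matrix_free)
print(full_matrix)
```

Both lines should print the same vector. The matrix-free version only ever stores vectors of length n, which is why it scales to much larger inputs than the full matrix.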


Copyright ©2025 Educative, Inc. All rights reserved