Chex

This lesson will introduce Chex.

Testing numerical code can be tricky, and JAX makes it harder: the same function may be transformed (for example with jit, vmap, or pmap) and run in parallel across GPUs/TPUs. Chex, a library from the JAX ecosystem, provides testing utilities such as:

  • Assertions.
  • Fakes for transformations such as jit and pmap, which make transformed code easier to debug.
  • Testing code across JIT and non-JIT versions.

Assertions

Standard Python type annotations cannot express the size or shape of a JAX array (DeviceArray), so Chex provides runtime assertions for them.

Primitives

Using assert_shape() and assert_rank(), we can validate the shape and the rank (number of dimensions) of a given JAX array.

import chex
import jax.numpy as jnp
from chex import assert_shape, assert_rank

x = jnp.ones((5, 5))
y = jnp.ones((2, 5, 3, 4))

assert_shape(x, (5, 5))    # passes silently (returns None)
# assert_shape(x, (2, 4))  # would raise AssertionError: inconsistent shapes
assert_rank(y, 4)          # passes: y has 4 dimensions

We can also check that several arrays have equal shapes using assert_equal_shape(), and use assert_type() to verify data types ...