Chex
This lesson will introduce Chex.
Testing numerical code can be tricky, and JAX adds to the difficulty because the same code may be compiled, vectorized, or run in parallel across GPUs/TPUs. The JAX ecosystem includes a library to help with this: Chex. Its utilities include:
- Assertions.
- Debugging transformed code (for example, code using vmap or pmap).
- Testing code across JIT and non-JIT versions.
Assertions
Standard Python type annotations cannot express the size or shape of a DeviceArray, so Chex provides its own assertions for these properties.
Primitives
Using assert_shape() and assert_rank(), we can validate the shape and the rank (number of dimensions) of a given JAX array.
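```python
import jax.numpy as jnp
from chex import assert_shape, assert_rank

x = jnp.ones((5, 5))
y = jnp.ones((2, 5, 3, 4))

# Chex assertions return None on success and raise AssertionError on failure.
assert_shape(x, (5, 5))    # passes: x has shape (5, 5)
# assert_shape(x, (2, 4))  # would raise an AssertionError due to inconsistent shapes
assert_rank(y, 4)          # passes: y has 4 dimensions
```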
We can also check that several arrays share the same shape using assert_equal_shape(), and use assert_type() to verify their data types ...