JAX and NumPy

This lesson explores the relationship between JAX and NumPy, since JAX can be treated as an accelerated version of NumPy that can also run on GPUs and TPUs.

This lesson assumes some familiarity with NumPy. Let’s start with how standard NumPy and JAX relate.

JAX NumPy

JAX has its own variant of NumPy, which we can import as:

import jax.numpy
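As a quick illustration, the sketch below runs the same computation once with NumPy and once with jax.numpy. It assumes the common alias `jnp` (the lesson itself imports the module as plain `jax.numpy`) and a recent JAX release that exposes `jax.Array`. The calls are identical, but the results are different array types.

```python
import numpy as np
import jax
import jax.numpy as jnp  # common alias (an assumption); `import jax.numpy` works the same way

# Build the same 2x3 array with each library.
x_np = np.arange(6.0).reshape(2, 3)
x_jnp = jnp.arange(6.0).reshape(2, 3)

# Same function names, same arguments: sum of sin^2 down each column.
np_result = np.sum(np.sin(x_np) ** 2, axis=0)
jnp_result = jnp.sum(jnp.sin(x_jnp) ** 2, axis=0)

# The results agree up to floating-point precision
# (JAX defaults to float32, NumPy to float64).
print(np.allclose(np_result, jnp_result))   # True

# ...but jax.numpy returns JAX arrays, not NumPy ndarrays.
print(isinstance(x_jnp, jax.Array))         # True
print(isinstance(x_jnp, np.ndarray))        # False
```

Because the API mirrors NumPy, porting code is often just a matter of swapping `np.` for `jnp.`, although the returned values are JAX arrays rather than NumPy arrays.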

One might worry that we have to learn NumPy all over again, but luckily jax.numpy closely mirrors the NumPy API, so the syntax is largely the same. For ...