Just-in-Time Compilation
Learn about ways to accelerate functions in JAX.
How fast is JAX?
JAX uses asynchronous dispatch: it does not wait for a computation to finish before returning control to the Python program. When we run an operation, JAX immediately returns an array whose value may still be computing, much like a future. Python only has to wait when the values are actually needed, for example when we print the output or convert the result to a NumPy array.
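A minimal sketch of this behavior, assuming a recent JAX version where results are `jax.Array` objects. The names `a`, `result`, and `host_result` are illustrative only. The call to `jnp.dot` returns right away, and the later `np.asarray` call is where Python actually waits for the data.

```python
import jax
import jax.numpy as jnp
import numpy as np

a = jnp.ones((2000, 2000))

# Dispatch: this line returns as soon as the work is queued,
# not necessarily after the matrix product has finished.
result = jnp.dot(a, a)
print(isinstance(result, jax.Array))  # True

# Converting to NumPy forces Python to wait until the values are ready.
host_result = np.asarray(result)
print(host_result[0, 0])  # each entry is a sum of 2000 ones
```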
Therefore, to measure how long a computation takes, we have to call block_until_ready() on the result, which makes Python wait until the computation completes; otherwise we only time the dispatch. Generally speaking, NumPy often outperforms JAX on the CPU, especially for small operations, while JAX tends to outperform NumPy on accelerators (GPUs and TPUs) and when functions are jitted.
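One way to time both libraries fairly is sketched below. The helper `time_call` and the 1000×1000 test matrix are our own assumptions, not part of the lesson. The helper calls `block_until_ready()` only when the result has that method (JAX arrays do, NumPy arrays do not). Timings depend heavily on hardware and library versions, so no numbers are shown.

```python
import time

import jax.numpy as jnp
import numpy as np


def time_call(fn, *args, repeats=5):
    """Return (best wall-clock time in seconds, last output) over `repeats` calls."""
    best = float("inf")
    out = None
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn(*args)
        # JAX returns before the work is done; wait so the timer covers the computation.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best, out


x_np = np.random.default_rng(0).standard_normal((1000, 1000), dtype=np.float32)
x_jnp = jnp.asarray(x_np)

np_time, np_out = time_call(np.dot, x_np, x_np)
jnp_time, jnp_out = time_call(jnp.dot, x_jnp, x_jnp)
print(f"NumPy: {np_time:.4f} s   JAX: {jnp_time:.4f} s")
```

Taking the best of several runs reduces noise from one-off costs such as memory allocation on the first call.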
Using jit() to speed up functions
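The jax.jit() function performs just-in-time compilation: it traces a Python function and compiles it with XLA into code optimized for the target device. jax.jit() expects a pure function. Any side effects in the function run only when the function is traced (typically on the first call), not every time it is called. Let’s create a pure function and measure its execution time without jit().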
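The sketch below makes the purity requirement concrete. The functions `impure_square` and `scaled` and the list `trace_log` are hypothetical names used only for illustration. Appending to a Python list is a side effect, so it happens only during tracing; reading a global variable is also impure, because its value is captured when the function is traced.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical log used only to observe side effects


@jax.jit
def impure_square(x):
    # This Python side effect runs only while JAX traces the function.
    trace_log.append("traced")
    return x * x


for _ in range(3):
    impure_square(jnp.arange(4.0))
print(len(trace_log))  # 1: traced once, then the cached compiled version is reused

# A new input shape triggers a new trace, so the side effect runs again.
impure_square(jnp.arange(8.0))
print(len(trace_log))  # 2

scale = 2.0


@jax.jit
def scaled(x):
    # Reading a global makes the function impure: its value is baked in at trace time.
    return x * scale


first = scaled(jnp.ones(3))   # traced with scale = 2.0
scale = 10.0
second = scaled(jnp.ones(3))  # same shape and dtype, so the old trace (scale = 2.0) is reused
print(second)  # values are still 2.0
```

Passing `scale` as an argument instead of reading it from the global scope would keep the function pure.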
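```python
import time
import jax.numpy as jnp

def test_fn(sample_rate=3000, frequency=3):
    x = jnp.arange(sample_rate)  # sample indices 0, 1, ..., sample_rate - 1
    y = jnp.sin(2 * jnp.pi * frequency * x / sample_rate)  # sine wave sampled at each x
    return jnp.dot(x, y)

start_time = time.time()
# block_until_ready() waits for the asynchronous computation to finish
x = test_fn().block_until_ready()
print("--- %s seconds ---" % (time.time() - start_time))
```

In the code above:
Lines 4–7: We define a function, ...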
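For comparison, the same function can be wrapped with jax.jit(). This is a minimal sketch, not the lesson's own code: it redefines a version of `test_fn` that samples the sine wave at each `x` so the block runs on its own, and `jitted_fn` is a hypothetical name. Marking `sample_rate` as static via `static_argnames` is our assumption; it only matters when `sample_rate` is passed explicitly, because it determines the array shape and cannot be a traced value.

```python
import time

import jax
import jax.numpy as jnp


def test_fn(sample_rate=3000, frequency=3):
    x = jnp.arange(sample_rate)
    y = jnp.sin(2 * jnp.pi * frequency * x / sample_rate)
    return jnp.dot(x, y)


# sample_rate sets the shape of jnp.arange, so it must be static (a Python value).
jitted_fn = jax.jit(test_fn, static_argnames=("sample_rate",))

# The first call traces and compiles the function, so time it separately.
start = time.perf_counter()
jitted_fn().block_until_ready()
print("first call (includes compilation): %.4f s" % (time.perf_counter() - start))

# Later calls with the same static values and input shapes reuse the compiled code.
start = time.perf_counter()
result = jitted_fn().block_until_ready()
print("subsequent call: %.4f s" % (time.perf_counter() - start))
```

Timing the first call separately matters because it includes compilation time, which can exceed the cost of running the function itself.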