...


Just-in-Time Compilation


Learn about ways to accelerate functions in JAX.

How fast is JAX?

JAX uses asynchronous dispatch: it does not wait for a computation to finish before returning control to the Python program. Instead, an operation immediately returns a JAX array whose value may still be computing, much like a future. Python is forced to wait for the result only when we need the concrete values, for example when we print the output or convert it to a NumPy array.
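A minimal sketch of this behavior, assuming a standard JAX installation on any backend (CPU, GPU, or TPU). The `jnp.dot` call returns as soon as the work is queued. Calling `block_until_ready()` or converting the result with `np.asarray()` makes Python wait for the actual values.

```python
import numpy as np
import jax.numpy as jnp

# Queue a matrix multiplication; the call returns without waiting for the result.
x = jnp.ones((1000, 1000))
y = jnp.dot(x, x)  # a jax.Array whose value may still be computing

# Explicitly wait for the computation without copying the data to NumPy.
y.block_until_ready()

# Converting to a NumPy array also forces Python to wait for the values.
host_copy = np.asarray(y)
print(type(host_copy), host_copy[0, 0])
```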

Therefore, to measure how long a computation takes, we have to wait for it to finish by calling block_until_ready() on the result. Converting the result to a NumPy array also forces a wait. Generally speaking, NumPy can outperform JAX for small operations on the CPU, where JAX's dispatch overhead dominates, but JAX tends to outperform NumPy on accelerators (GPUs and TPUs) and when functions are compiled with jit().
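The sketch below shows one way to time NumPy and JAX fairly. It assumes that a fair comparison needs a warm-up call (so one-time compilation is excluded) and an explicit wait on JAX results. `time_call` is a hypothetical helper name, and the 500-by-500 matrix size is arbitrary. The measured times depend entirely on your hardware and backend, so no particular result is implied.

```python
import time
import numpy as np
import jax.numpy as jnp

def time_call(fn, *args, repeats=10):
    """Return the mean wall-clock time of fn(*args) in seconds."""
    def run():
        out = fn(*args)
        # JAX results are computed asynchronously, so block until they are ready.
        # NumPy arrays have no block_until_ready() and are already complete.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()

    run()  # warm-up call; for JAX this also triggers any one-time compilation
    start = time.perf_counter()
    for _ in range(repeats):
        run()
    return (time.perf_counter() - start) / repeats

a_np = np.random.rand(500, 500).astype(np.float32)
a_jnp = jnp.asarray(a_np)

numpy_time = time_call(np.dot, a_np, a_np)
jax_time = time_call(jnp.dot, a_jnp, a_jnp)
print(f"NumPy: {numpy_time:.6f} s, JAX: {jax_time:.6f} s")
```

Without the `block_until_ready()` call, the JAX timing would mostly measure how fast work is dispatched, not how long it takes to run.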

Using jit() to speed up functions

The jit() method performs just-in-time (JIT) compilation with XLA. JIT compilation, also called dynamic translation, compiles code at runtime, which can speed up its execution. The jax.jit() method expects a pure function. Python side effects inside it, such as print calls or updates to global variables, run only while JAX traces the function. Tracing happens on the first call and again only when the function is called with inputs of a new shape or dtype. Let's create a pure function and time its execution without jit().
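Before we do that, here is a minimal sketch of the side-effect rule. The names `scaled_sum`, `fast_scaled_sum`, and `trace_count` are hypothetical and used only for illustration. The global counter stands in for any Python side effect. It increments only when `jax.jit` traces the function, not on every call.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter used only to observe the side effect

def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return jnp.sum(2.0 * x)

fast_scaled_sum = jax.jit(scaled_sum)

x = jnp.arange(5.0)
print(fast_scaled_sum(x))  # first call: traces and compiles, trace_count becomes 1
print(fast_scaled_sum(x))  # same shape and dtype: reuses the compiled code
fast_scaled_sum(jnp.arange(3.0))  # new shape: retraced, trace_count becomes 2
print(trace_count)
```

This is why a jitted function should not rely on side effects. They do not run once per call, only once per trace.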

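import time
import jax.numpy as jnp

def test_fn(sample_rate=3000, frequency=3):
    x = jnp.arange(sample_rate)
    y = jnp.sin(2 * jnp.pi * frequency * (x / sample_rate))  # sine wave sampled at each step
    return jnp.dot(x, y)

start_time = time.time()
x = test_fn().block_until_ready()  # wait for the asynchronous result before stopping the clock
print("--- %s seconds ---" % (time.time() - start_time))

In the code above:

  • Lines 4–7: We define a function, ...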