Asynchronous Dispatch

This lesson will introduce asynchronous dispatch.

We'll cover the following...

Introductory example

We have already talked about the JIT’s usefulness. Now let’s take a look at the difference in speed between a JIT-compiled function and a normal function:

import timeit

import numpy as np
import jax.numpy as jnp
from jax import jit

# Normal function: plain NumPy dot product
def FuncA(x, y):
    return np.asarray(np.dot(x, y))

# JIT-compiled version
@jit
def FuncB(x, y):
    return jnp.dot(x, y)

x = jnp.ones((1000, 1000))

start_time = timeit.default_timer()
FuncA(x, x)
print("Total time taken for normal function:", timeit.default_timer() - start_time)

start_time = timeit.default_timer()
FuncB(x, x)
print("Total time taken for JIT function:", timeit.default_timer() - start_time)

If we run the above code on the CPU, the output should be fairly straightforward.

If we run it on a GPU or TPU, on the other hand, the output is quite astonishing. The JAX JIT function is almost three times faster than the “non-JIT” function ...
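Since this lesson is about asynchronous dispatch, it may help to see how a timing like the one above can be separated into "time to dispatch" and "time to actually compute". The following is a minimal sketch, not the lesson's own code: the function name `matmul` and the variable names are hypothetical, and it assumes that calling `block_until_ready()` on a JAX array is an acceptable way to wait for the result. No particular timings are claimed, since they depend on the hardware.

```python
import timeit

import jax.numpy as jnp
from jax import jit


# Hypothetical JIT-compiled function, mirroring the dot product used above.
@jit
def matmul(x, y):
    return jnp.dot(x, y)


x = jnp.ones((1000, 1000))

# Warm-up call so that compilation time is excluded from the measurements.
matmul(x, x).block_until_ready()

# Measure only the dispatch: the call may return before the result exists.
start = timeit.default_timer()
result = matmul(x, x)
dispatch_time = timeit.default_timer() - start

# Wait for the pending work so it does not leak into the next measurement.
result.block_until_ready()

# Measure dispatch plus computation by explicitly waiting for the result.
start = timeit.default_timer()
result = matmul(x, x).block_until_ready()
total_time = timeit.default_timer() - start

print("Dispatch only:", dispatch_time)
print("Dispatch + compute:", total_time)
```

If the two printed numbers differ noticeably on an accelerator, that gap is a hint that the un-blocked call returned before the computation had finished, so a timing taken without waiting does not measure the full cost of the operation.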