Asynchronous Dispatch
This lesson will introduce asynchronous dispatch.
Introductory example
We have already discussed why JIT compilation is useful. Now let’s compare the speed of a JIT-compiled function with that of a normal function:
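```python
import timeit

import numpy as np
import jax.numpy as jnp
from jax import jit


# Normal function: np.dot converts the JAX arrays to NumPy arrays and computes on the host
def FuncA(x, y):
    return np.asarray(np.dot(x, y))


# JIT version: compiled with XLA and run on the default device (CPU, GPU or TPU)
@jit
def FuncB(x, y):
    return jnp.dot(x, y)


x = jnp.ones((1000, 1000))

start_time = timeit.default_timer()
FuncA(x, x)
print("Total time taken for normal function:", timeit.default_timer() - start_time)

# Note: this first call of FuncB also includes tracing and compilation time
start_time = timeit.default_timer()
FuncB(x, x)
print("Total time taken for JIT function:", timeit.default_timer() - start_time)
```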
If we run the above code on a CPU, the output holds no surprises.
If we run it on a GPU or TPU, on the other hand, the output is quite surprising: the JIT-compiled function is almost three times faster than the “non-JIT” function ...
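Asynchronous dispatch, the topic of this lesson, is the usual reason a timing like this can mislead: JAX enqueues the work on the device and returns an array whose value may not be computed yet. The snippet below is a minimal sketch, assuming a recent JAX version where arrays expose the `block_until_ready()` method. It warms up the JIT function once so compilation is excluded, then times the same call twice: once measuring only the dispatch, and once waiting for the result. The function name `jit_dot` is a hypothetical name chosen for illustration.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def jit_dot(x, y):
    # Hypothetical helper: same computation as FuncB above
    return jnp.dot(x, y)


x = jnp.ones((1000, 1000))

# Warm-up call: the first call traces and compiles jit_dot, so keep it out of
# both measurements. block_until_ready() waits until the result is computed.
jit_dot(x, x).block_until_ready()

# 1) Dispatch only: the call may return as soon as the work is enqueued
start = timeit.default_timer()
result = jit_dot(x, x)
dispatch_time = timeit.default_timer() - start

# 2) Dispatch plus computation: wait for the device to finish
start = timeit.default_timer()
result = jit_dot(x, x).block_until_ready()
total_time = timeit.default_timer() - start

print("Dispatch only:", dispatch_time)
print("Dispatch + compute:", total_time)

# Converting to a NumPy array also forces the computation to finish;
# each entry of the product is the sum of 1000 ones
print(np.asarray(result)[0, 0])
```

On an accelerator, the dispatch-only time is typically much smaller than the blocking time. For a fair benchmark of JAX code, exclude the first (compiling) call and wait on the result with `block_until_ready()` before stopping the timer.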