Auto Parallelization
This lesson will introduce auto parallelization in JAX.
With ever-increasing amounts of data and computational resources, there is a pressing need for parallel processing, and JAX supports it directly. Just as vmap() vectorizes a function over a batch axis, pmap() runs a given function in parallel across multiple devices, with each device processing one slice of the input's leading axis.
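Remember: pmap() is meant for machines with multiple accelerators (GPUs or TPUs), so the code snippets in this lesson are provided mainly as a guide. vmap() runs on any backend, including a plain CPU, but pmap() raises an error when the mapped axis is larger than the number of available devices, which on a standard CPU-only machine is usually just one.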
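If you only have a CPU, one way to experiment with pmap() anyway is to ask XLA's CPU backend to expose several virtual devices. The sketch below assumes the `--xla_force_host_platform_device_count` XLA flag, which must be set before JAX is imported; the device count of 8 is an arbitrary choice for illustration. Virtual CPU devices share the same physical cores, so this is useful for testing code paths, not for real speedups.

```python
import os

# Must be set before importing JAX: asks XLA's CPU backend to expose
# 8 virtual devices so pmap can be tried without GPUs/TPUs
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

cpu_devices = jax.devices("cpu")
print(len(cpu_devices))  # 8 virtual CPU devices

# One row of 3 values per virtual device; each device sums its own row
xs = jnp.arange(8 * 3, dtype=jnp.float32).reshape(8, 3)
sums = jax.pmap(jnp.sum, devices=cpu_devices)(xs)
print(sums)
```

Row i holds [3i, 3i+1, 3i+2], so each device returns 9i + 3, giving [3, 12, 21, ..., 66] across the 8 virtual devices.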