
Installation

Explore the installation methods for JAX on CPU, GPU (with CUDA), and TPU setups. Learn how to upgrade pip, install CUDA and cuDNN, and build JAX from source with its required dependencies to set up your deep learning environment.

Installing JAX is as easy as installing NumPy or any other Python library with pip.

CPU-only installation

If you don’t have a GPU, you can still run JAX on your CPU.

  1. Upgrade pip: pip install --upgrade pip
  2. Install JAX with the CPU extra: pip install --upgrade "jax[cpu]"
```shell
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
```
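To confirm that the install worked, one option is to import JAX from the shell and ask which backend it selected. The sketch below wraps that in a hypothetical helper, `verify_jax` (not part of JAX), and assumes that the `python` on your PATH is the interpreter `pip` installed into. If the import fails, it prints a hint instead of stopping.

```shell
# verify_jax is a hypothetical helper for illustration, not part of JAX.
# It imports JAX, reports the version and the selected backend, and runs a
# tiny computation; if the import fails, it prints a hint instead.
verify_jax() {
  python -c "
import jax
import jax.numpy as jnp
print('version:', jax.__version__)
print('backend:', jax.default_backend())
print('sum:', int(jnp.arange(5).sum()))
" || echo "JAX could not be imported; re-run the pip install step"
}

verify_jax
```

On a CPU-only install, the `backend:` line should read `cpu`, and the `sum:` line should show `10` (0 + 1 + 2 + 3 + 4).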

GPU installation

Like other deep learning frameworks, JAX supports NVIDIA GPUs through CUDA, but we have to install CUDA and cuDNN ourselves.
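Before installing anything, it helps to check what the machine already has. The sketch below uses a hypothetical helper, `check_gpu_env`, that queries the NVIDIA driver with `nvidia-smi` and the CUDA toolkit with `nvcc --version`. When either tool is missing, it prints a message instead of failing. It assumes a Linux-like shell. It does not check cuDNN, because cuDNN's location depends on how it was installed.

```shell
# check_gpu_env is a hypothetical helper for illustration only.
# It reports the NVIDIA driver and CUDA toolkit, if present, without failing.
check_gpu_env() {
  # The NVIDIA driver ships nvidia-smi; list each GPU's name and driver version.
  if command -v nvidia-smi >/dev/null 2>&1; then
    nvidia-smi --query-gpu=name,driver_version --format=csv,noheader
  else
    echo "nvidia-smi not found: no NVIDIA driver detected"
  fi
  # The CUDA toolkit ships nvcc; print its version if it is on PATH.
  if command -v nvcc >/dev/null 2>&1; then
    nvcc --version
  else
    echo "nvcc not found: CUDA toolkit missing or not on PATH"
  fi
}

check_gpu_env
```

If `nvidia-smi` reports a GPU but `nvcc` is not found, the driver is present and the CUDA toolkit still needs to be installed.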

Installing CUDA

Since CUDA depends on a number of parameters ...