How to Install JAX with uv
JAX distributes GPU builds through pip extras rather than separate package indexes (see Why Installing GPU Python Packages Is So Complicated for background on why this matters). Running uv add 'jax[cuda13]' pulls CUDA libraries from PyPI with no index configuration in pyproject.toml, unlike PyTorch, which requires [[tool.uv.index]] routing.
Install JAX for CPU
For CPU-only use on any platform (Linux, macOS, Windows):
```bash
uv add jax
```
Or with the uv pip interface:
```bash
uv pip install jax
```
This installs jax and jaxlib from PyPI. No GPU libraries are included.
Install JAX with CUDA (NVIDIA GPU)
CUDA builds are available for Linux only. Choose the extra for the CUDA major version your NVIDIA driver supports (see Check your NVIDIA driver version below):
```bash
uv add 'jax[cuda13]'
```
Or for CUDA 12:
```bash
uv add 'jax[cuda12]'
```
The cuda13 extra installs jax-cuda13-plugin, jax-cuda13-pjrt, and NVIDIA runtime libraries (nvidia-cuda-runtime, nvidia-cudnn-cu13, nvidia-nccl-cu13, and others) from PyPI. No [[tool.uv.index]] or [tool.uv.sources] configuration is needed.
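To see which of these packages an extra actually pulled into the environment, a stdlib-only sketch can list installed distributions by name prefix. The prefixes are an assumption based on the package names above (the jax prefix also matches unrelated packages such as jaxtyping), and jax_related_distributions is a hypothetical helper, not part of JAX or uv.

```python
from importlib import metadata

# Name prefixes to report; assumed from the package names listed above.
PREFIXES = ("jax", "nvidia-", "libtpu")


def jax_related_distributions() -> list[str]:
    """Return sorted, normalized names of installed distributions matching PREFIXES."""
    names = set()
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if not name:
            continue  # skip distributions with broken metadata
        # Normalize roughly like PEP 503: lowercase, '_' and '.' become '-'.
        normalized = name.lower().replace("_", "-").replace(".", "-")
        if normalized.startswith(PREFIXES):
            names.add(normalized)
    return sorted(names)


if __name__ == "__main__":
    for name in jax_related_distributions():
        print(name)
```

Running it with uv run python after uv add 'jax[cuda13]' and again after uv add 'jax[cuda13-local]' should show the nvidia-cuda-runtime and nvidia-cudnn packages only in the first case.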
Check your NVIDIA driver version
| CUDA version | Minimum NVIDIA driver | Supported GPU architectures |
|---|---|---|
| CUDA 13 | 580+ | SM 7.5+ (Turing and newer) |
| CUDA 12 | 525+ | SM 5.2+ (Maxwell and newer) |
Check the installed driver version with nvidia-smi.
Use a locally installed CUDA toolkit
If the target machine already has CUDA installed system-wide (for example, on an HPC cluster), use the cuda13-local or cuda12-local extra to skip downloading the CUDA pip wheels:
```bash
uv add 'jax[cuda13-local]'
```
This installs jax-cuda13-plugin without its [with-cuda] extra, so no nvidia-cuda-runtime or nvidia-cudnn pip packages are pulled in. The system CUDA libraries must be on LD_LIBRARY_PATH.
Install JAX for Google Cloud TPU
On a Google Cloud TPU VM, install with the tpu extra:
```bash
uv add 'jax[tpu]'
```
This installs libtpu alongside jaxlib. TPU support requires running on a Google Cloud TPU VM.
Support both CPU and GPU in one project
For projects that run on Linux with a GPU and on macOS or Windows without one, use environment markers to select the right extra per platform:
```toml
[project]
dependencies = [
    "jax[cuda13]; sys_platform == 'linux'",
    "jax; sys_platform != 'linux'",
]
```
Running uv sync on Linux installs the CUDA build. On macOS or Windows, it installs CPU-only JAX. The lock file captures both resolution paths.
Verify the installation
Confirm JAX is installed and check which device backend is active:
```bash
uv run python -c "import jax; print(jax.__version__); print(jax.devices())"
```
On a CPU-only install, this prints [CpuDevice(id=0)]. On a machine with a CUDA-enabled GPU, it prints [CudaDevice(id=0)] (or multiple devices for multi-GPU systems).
Check Python and platform compatibility
JAX 0.10.0 requires Python 3.11 or later. It publishes wheels for CPython 3.11 through 3.14, including free-threaded builds (3.13t and 3.14t). jaxlib publishes wheels for Linux (x86_64 and aarch64), macOS (ARM64), and Windows (x86_64).
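Before installing, a short stdlib check can confirm that the current interpreter and platform fall inside those ranges and report whether the interpreter is a free-threaded build. The bounds and platform pairs are copied from the paragraph above for JAX 0.10.0 and will drift with later releases; has_jax_wheels, has_jaxlib_wheel, and is_free_threaded_build are hypothetical helpers. Py_GIL_DISABLED is the config variable CPython 3.13+ sets on free-threaded builds, and the machine strings are what platform.machine() typically reports on each OS.

```python
import platform
import sys
import sysconfig

# Python versions with published JAX wheels, per the paragraph above.
OLDEST_SUPPORTED = (3, 11)
NEWEST_WITH_WHEELS = (3, 14)

# (platform.system(), platform.machine()) pairs matching the listed jaxlib wheels.
JAXLIB_PLATFORMS = {
    ("Linux", "x86_64"),
    ("Linux", "aarch64"),
    ("Darwin", "arm64"),
    ("Windows", "AMD64"),
}


def has_jax_wheels(version: tuple[int, int]) -> bool:
    """True if (major, minor) is within the published wheel range."""
    return OLDEST_SUPPORTED <= version <= NEWEST_WITH_WHEELS


def has_jaxlib_wheel(system: str, machine: str) -> bool:
    """True if the OS/architecture pair has a published jaxlib wheel."""
    return (system, machine) in JAXLIB_PLATFORMS


def is_free_threaded_build() -> bool:
    # 1 on free-threaded (3.13t/3.14t) builds; 0 or None on regular builds.
    return bool(sysconfig.get_config_var("Py_GIL_DISABLED"))


if __name__ == "__main__":
    current = tuple(sys.version_info[:2])
    print(f"Python {current[0]}.{current[1]} has wheels: {has_jax_wheels(current)}")
    print(f"platform supported: {has_jaxlib_wheel(platform.system(), platform.machine())}")
    print(f"free-threaded build: {is_free_threaded_build()}")
```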
Learn More
- Official JAX installation guide
- Why Installing GPU Python Packages Is So Complicated
- How to Install PyTorch with uv for the index-routing approach that PyTorch requires
- uv vs. pixi vs. conda for Scientific Python
- JAX GitHub repository