
How to Install JAX with uv

JAX distributes GPU builds through pip extras rather than separate package indexes (see "Why Installing GPU Python Packages Is So Complicated" for background on why this matters). Running uv add 'jax[cuda13]' pulls the CUDA libraries from PyPI without any index configuration in pyproject.toml, unlike PyTorch, which requires [[tool.uv.index]] routing.

Install JAX for CPU

For CPU-only use on any platform (Linux, macOS, Windows):

uv add jax

Or with the uv pip interface:

uv pip install jax

This installs jax and jaxlib from PyPI. No GPU libraries are included.
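To confirm which versions ended up in the environment without importing JAX itself, you can query the installed distribution metadata. The following is a minimal sketch using only the standard library; the function name installed_versions is hypothetical. Run it with uv run python followed by whatever file name you save it under.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(names: Iterable[str] = ("jax", "jaxlib")) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if it is absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Reading metadata instead of calling import jax keeps the check fast and avoids initializing any backend.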

Install JAX with CUDA (NVIDIA GPU)

CUDA builds are available for Linux only. Choose the extra that matches the CUDA major version your NVIDIA driver supports:

uv add 'jax[cuda13]'

Or for CUDA 12:

uv add 'jax[cuda12]'

The cuda13 extra installs jax-cuda13-plugin, jax-cuda13-pjrt, and NVIDIA runtime libraries (nvidia-cuda-runtime, nvidia-cudnn-cu13, nvidia-nccl-cu13, and others) from PyPI. No [[tool.uv.index]] or [tool.uv.sources] configuration is needed.
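To see exactly which GPU wheels the extra pulled in, you can list installed distributions by name prefix. This is a sketch under an assumption: it treats any distribution whose name starts with jax-cuda or nvidia- as part of the GPU stack, which may over- or under-match in a larger project. GPU_PREFIXES and gpu_distributions are hypothetical names.

```python
from importlib import metadata
from typing import List, Tuple

# Assumed name prefixes for the GPU wheels that jax[cuda13] pulls in;
# adjust them if your dependency tree differs.
GPU_PREFIXES = ("jax-cuda", "nvidia-")


def gpu_distributions(prefixes: Tuple[str, ...] = GPU_PREFIXES) -> List[Tuple[str, str]]:
    """Return sorted (name, version) pairs for installed distributions matching the prefixes."""
    found = set()
    for dist in metadata.distributions():
        meta = dist.metadata
        name = meta.get("Name") if meta is not None else None
        if not name:
            continue  # skip broken installs that lack a Name field
        # Normalize so "nvidia_cudnn_cu13" and "nvidia-cudnn-cu13" compare equal.
        normalized = name.lower().replace("_", "-")
        if normalized.startswith(prefixes):
            found.add((normalized, dist.version))
    return sorted(found)


if __name__ == "__main__":
    packages = gpu_distributions()
    if not packages:
        print("no jax-cuda* or nvidia-* distributions found (CPU-only install?)")
    for name, version in packages:
        print(f"{name}=={version}")
```

On a CPU-only install the list should be empty; after uv add 'jax[cuda13]' it should include the plugin and the NVIDIA runtime wheels named above.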

Check your NVIDIA driver version

CUDA version | Minimum NVIDIA driver | Supported GPU architectures
CUDA 13      | 580+                  | SM 7.5+ (Turing and newer)
CUDA 12      | 525+                  | SM 5.2+ (Maxwell and newer)

Check the installed driver version with nvidia-smi.
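The driver lookup can also be scripted. The sketch below asks nvidia-smi for the driver version (via its --query-gpu=driver_version --format=csv,noheader options) and compares only the major component against the minimums in the table above. It is a simplification: it ignores minor driver versions and does not check GPU architecture (SM level). MIN_DRIVER, compatible_extras, and query_driver_version are hypothetical names.

```python
import shutil
import subprocess
from typing import List, Optional

# Minimum driver major versions, taken from the table above.
MIN_DRIVER = {"cuda13": 580, "cuda12": 525}


def driver_major(version: str) -> int:
    """Return the major component of a driver string such as '580.65.06'."""
    return int(version.strip().split(".")[0])


def compatible_extras(driver_version: str) -> List[str]:
    """List the jax CUDA extras whose minimum driver is satisfied, newest first."""
    major = driver_major(driver_version)
    return [extra for extra, minimum in MIN_DRIVER.items() if major >= minimum]


def query_driver_version() -> Optional[str]:
    """Ask nvidia-smi for the driver version; None if it is missing or fails."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True, text=True, check=False,
    )
    if result.returncode != 0 or not result.stdout.strip():
        return None
    # nvidia-smi prints one line per GPU; all GPUs share one driver, so take the first.
    return result.stdout.strip().splitlines()[0]


if __name__ == "__main__":
    version = query_driver_version()
    if version is None:
        print("nvidia-smi not found or no GPU detected: use CPU-only jax")
    else:
        print(f"driver {version}: compatible extras {compatible_extras(version)}")
```

For example, a driver reporting 535.x satisfies only the cuda12 row, so jax[cuda13] would need a driver upgrade first.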

Use a locally installed CUDA toolkit

If the target machine already has CUDA installed system-wide (for example, on an HPC cluster), use the -local extras to skip downloading the CUDA pip wheels:

uv add 'jax[cuda13-local]'

This installs jax-cuda13-plugin without the [with-cuda] extra, so no nvidia-cuda-runtime or nvidia-cudnn pip packages are pulled in. The system CUDA toolkit must be on LD_LIBRARY_PATH.
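Before relying on a -local extra, it can help to confirm that the system libraries are actually visible on LD_LIBRARY_PATH. This sketch scans each directory listed in that variable for files whose names start with a given prefix. Using libcudart.so and libcudnn.so as the representative libraries to look for is an assumption, and find_libraries is a hypothetical name.

```python
import os
from typing import List, Mapping, Optional


def find_libraries(prefix: str, env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return files on LD_LIBRARY_PATH whose names start with `prefix`."""
    env = os.environ if env is None else env
    matches: List[str] = []
    for directory in env.get("LD_LIBRARY_PATH", "").split(os.pathsep):
        if not directory or not os.path.isdir(directory):
            continue  # skip empty entries and directories that do not exist
        try:
            entries = sorted(os.listdir(directory))
        except OSError:
            continue  # unreadable directory
        matches.extend(os.path.join(directory, e) for e in entries if e.startswith(prefix))
    return matches


if __name__ == "__main__":
    # Representative libraries to look for (an assumption, not an exhaustive list).
    for lib in ("libcudart.so", "libcudnn.so"):
        hits = find_libraries(lib)
        print(f"{lib}: {hits[0] if hits else 'NOT FOUND on LD_LIBRARY_PATH'}")
```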

Install JAX for Google Cloud TPU

On a Google Cloud TPU VM, install with the tpu extra:

uv add 'jax[tpu]'

This installs libtpu alongside jaxlib. TPU support requires running on a Google Cloud TPU VM.

Support both CPU and GPU in one project

For projects that run on Linux with a GPU and on macOS or Windows without one, use environment markers to select the right extra per platform:

[project]
dependencies = [
    "jax[cuda13]; sys_platform == 'linux'",
    "jax; sys_platform != 'linux'",
]

Running uv sync on Linux installs the CUDA build. On macOS or Windows, it installs CPU-only JAX. The lock file captures both resolution paths.
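To make the marker logic concrete, here is a toy sketch that evaluates only the sys_platform == '...' and sys_platform != '...' forms used in the dependencies above against a given sys.platform value. It is not how uv evaluates markers; real tools implement the full PEP 508 marker grammar (for example, the packaging library's Marker class). DEPENDENCIES and requirements_for are hypothetical names.

```python
import re
import sys
from typing import List, Sequence

# Copied from the [project] dependencies above.
DEPENDENCIES = [
    "jax[cuda13]; sys_platform == 'linux'",
    "jax; sys_platform != 'linux'",
]

# Matches only the simple `sys_platform == '...'` / `!= '...'` form used here.
MARKER_RE = re.compile(r"^sys_platform\s*(==|!=)\s*'([^']*)'$")


def requirements_for(platform: str, deps: Sequence[str] = DEPENDENCIES) -> List[str]:
    """Return the requirements whose marker is true for the given sys.platform value."""
    selected: List[str] = []
    for dep in deps:
        requirement, _, marker = (part.strip() for part in dep.partition(";"))
        if not marker:
            selected.append(requirement)  # no marker: applies everywhere
            continue
        match = MARKER_RE.match(marker)
        if match is None:
            raise ValueError(f"unsupported marker: {marker!r}")
        op, value = match.groups()
        # '==' selects on equality; '!=' selects on inequality.
        if (platform == value) == (op == "=="):
            selected.append(requirement)
    return selected


if __name__ == "__main__":
    print(f"{sys.platform}: {requirements_for(sys.platform)}")
```

Because the two markers are exact complements, every platform gets exactly one jax requirement: "linux" selects jax[cuda13], while "darwin" and "win32" select plain jax.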

Verify the installation

Confirm JAX is installed and check which device backend is active:

uv run python -c "import jax; print(jax.__version__); print(jax.devices())"

The first line is the JAX version. On a CPU-only install, the second line is [CpuDevice(id=0)]. On a machine with a CUDA-enabled GPU, it is [CudaDevice(id=0)] (or one entry per GPU on multi-GPU systems).
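In scripts or CI jobs, a CPU fallback can be easy to miss in printed output, so you may prefer to inspect the backend programmatically. This sketch uses jax.default_backend() and jax.devices(), and returns None when JAX cannot be imported; jax_backend_summary and the CPU warning are hypothetical additions, not part of JAX.

```python
from typing import Any, Dict, Optional


def jax_backend_summary() -> Optional[Dict[str, Any]]:
    """Return JAX's version, default backend, and devices, or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    return {
        "version": jax.__version__,
        "backend": jax.default_backend(),  # e.g. "cpu", "gpu", or "tpu"
        # repr() matches the list format printed by the one-liner above.
        "devices": [repr(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    summary = jax_backend_summary()
    if summary is None:
        print("jax is not importable in this environment")
    else:
        print(summary)
        if summary["backend"] == "cpu":
            print("warning: running on CPU; check the installed extra and driver")
```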

Check Python and platform compatibility

JAX 0.10.0 requires Python 3.11 or later. It publishes wheels for CPython 3.11 through 3.14, including free-threaded builds (3.13t and 3.14t). jaxlib publishes wheels for Linux (x86_64 and aarch64), macOS (ARM64), and Windows (x86_64).
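The compatibility matrix above can be encoded as a quick pre-install check. This sketch hard-codes the stated Python range and platform list as a snapshot, so it will go stale as JAX releases change; it also relies on platform.machine() reporting 64-bit x86 Windows as "AMD64" and Apple Silicon macOS as "arm64". SUPPORTED_PYTHON, SUPPORTED_PLATFORMS, and has_wheel are hypothetical names.

```python
import platform
import sys
import sysconfig
from typing import Tuple

# Snapshot of the wheel matrix described above.
SUPPORTED_PYTHON = ((3, 11), (3, 14))  # inclusive range of CPython versions
SUPPORTED_PLATFORMS = {
    ("Linux", "x86_64"),
    ("Linux", "aarch64"),
    ("Darwin", "arm64"),
    ("Windows", "AMD64"),  # platform.machine() reports x86_64 Windows as "AMD64"
}


def has_wheel(version: Tuple[int, int], system: str, machine: str) -> bool:
    """True if the (Python version, OS, CPU) combination appears in the matrix."""
    low, high = SUPPORTED_PYTHON
    return low <= version <= high and (system, machine) in SUPPORTED_PLATFORMS


if __name__ == "__main__":
    version = tuple(sys.version_info[:2])
    # Py_GIL_DISABLED is set on free-threaded builds such as 3.13t and 3.14t.
    build = "free-threaded" if sysconfig.get_config_var("Py_GIL_DISABLED") else "standard"
    ok = has_wheel(version, platform.system(), platform.machine())
    print(f"Python {version[0]}.{version[1]} ({build}) on "
          f"{platform.system()} {platform.machine()}: "
          f"{'wheels expected' if ok else 'no matching jaxlib wheel'}")
```

Intel macOS is absent from the list above, so ("Darwin", "x86_64") returns False.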
