# How to Install JAX with uv


JAX distributes GPU builds through [pip extras](https://pydevtools.com/handbook/explanation/what-are-optional-dependencies-and-dependency-groups.md) rather than separate package indexes (see [Why Installing GPU Python Packages Is So Complicated](https://pydevtools.com/handbook/explanation/installing-cuda-python-packages.md) for background on why this matters). Running `uv add 'jax[cuda13]'` pulls CUDA libraries from PyPI with no index configuration in `pyproject.toml`, unlike [PyTorch](https://pydevtools.com/handbook/how-to/how-to-install-pytorch-with-uv.md), which requires `[[tool.uv.index]]` routing.
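You can see the extras mechanism directly in the metadata of an installed `jax` distribution. The sketch below uses the standard library's `importlib.metadata` to read the extras `jax` declares and group its requirements by extra. It assumes `jax` is already installed in the current environment (it returns an empty dict otherwise). `jax_extras` is a hypothetical helper name, and the exact extras and pinned plugin versions depend on the JAX release.

```python
from importlib.metadata import PackageNotFoundError, metadata, requires


def jax_extras() -> dict[str, list[str]]:
    """Map each extra declared by the installed jax to the requirements it adds.

    Hypothetical helper; returns {} when jax is not installed.
    """
    try:
        extras = metadata("jax").get_all("Provides-Extra") or []
        reqs = requires("jax") or []
    except PackageNotFoundError:
        return {}
    grouped = {}
    for extra in extras:
        # Extra-specific requirements carry a marker such as: extra == "cuda13"
        markers = (f'extra == "{extra}"', f"extra == '{extra}'")
        grouped[extra] = [r for r in reqs if any(m in r for m in markers)]
    return grouped


if __name__ == "__main__":
    for extra, reqs in jax_extras().items():
        print(f"[{extra}]")
        for req in reqs:
            print("   ", req)
```

Run it with `uv run python` inside the project. Every extra resolves to ordinary PyPI requirements, which is why no custom index is needed.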

## Install JAX for CPU

For CPU-only use on any supported platform (Linux, macOS, Windows):

```bash
uv add jax
```

Or with the [uv pip](https://pydevtools.com/handbook/reference/uv.md) interface:

```bash
uv pip install jax
```

This installs `jax` and `jaxlib` from [PyPI](https://pydevtools.com/handbook/explanation/what-is-pypi.md). No GPU libraries are included.

## Install JAX with CUDA (NVIDIA GPU)

CUDA builds are available for Linux only. Because the CUDA libraries come from PyPI, no system CUDA toolkit is needed. Instead, choose the extra based on what your NVIDIA driver supports:

```bash
uv add 'jax[cuda13]'
```

Or for CUDA 12:

```bash
uv add 'jax[cuda12]'
```

The `cuda13` extra installs `jax-cuda13-plugin`, `jax-cuda13-pjrt`, and NVIDIA runtime libraries (`nvidia-cuda-runtime`, `nvidia-cudnn-cu13`, `nvidia-nccl-cu13`, and others) from PyPI. No `[[tool.uv.index]]` or `[tool.uv.sources]` configuration is needed.

### Check your NVIDIA driver version

| CUDA version | Minimum NVIDIA driver | Supported GPU architectures |
|---|---|---|
| CUDA 13 | 580+ | SM 7.5+ (Turing and newer) |
| CUDA 12 | 525+ | SM 5.2+ (Maxwell and newer) |

Check the installed driver version with `nvidia-smi`.

### Use a locally installed CUDA toolkit

If the target machine already has CUDA installed system-wide (for example, on an HPC cluster), use the `-local` extras to skip downloading the CUDA wheels from PyPI:

```bash
uv add 'jax[cuda13-local]'
```

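This installs `jax-cuda13-plugin` without the `[with-cuda]` extra, so no `nvidia-cuda-runtime` or `nvidia-cudnn` pip packages are pulled in. The system CUDA libraries, including cuDNN, must therefore be findable by the dynamic loader, typically by adding their directories to `LD_LIBRARY_PATH`.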

## Install JAX for Google Cloud TPU

On a Google Cloud TPU VM, install with the `tpu` extra:

```bash
uv add 'jax[tpu]'
```

This installs `libtpu` alongside `jaxlib`. TPU support requires running on a Google Cloud TPU VM.
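To confirm after installation that JAX actually selected the TPU, you can check `jax.default_backend()`, which reports the platform JAX picked for computation. In this minimal sketch, `running_on_tpu` is a hypothetical helper. It imports JAX lazily, so the check returns `False` instead of crashing when JAX is missing.

```python
import importlib.util


def running_on_tpu() -> bool:
    """Return True if JAX is installed and its default backend is the TPU."""
    if importlib.util.find_spec("jax") is None:
        return False
    import jax  # imported lazily so the check works without JAX installed

    return jax.default_backend() == "tpu"


if __name__ == "__main__":
    print("TPU backend active:", running_on_tpu())
```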

## Support both CPU and GPU in one project

For projects that run on Linux with a GPU and on macOS or Windows without one, use [environment markers](https://pydevtools.com/handbook/explanation/what-are-optional-dependencies-and-dependency-groups.md) to select the right extra per platform:

```toml
[project]
dependencies = [
    "jax[cuda13]; sys_platform == 'linux'",
    "jax; sys_platform != 'linux'",
]
```

Running `uv sync` on Linux installs the CUDA build. On macOS or Windows, it installs CPU-only JAX. The [lock file](https://pydevtools.com/handbook/explanation/what-is-a-lock-file.md) captures both resolution paths.

## Verify the installation

Confirm JAX is installed and check which device backend is active:

```bash
uv run python -c "import jax; print(jax.__version__); print(jax.devices())"
```

The first line is the JAX version. On a CPU-only install, the device list is `[CpuDevice(id=0)]`. On a machine with a CUDA-enabled GPU, it is `[CudaDevice(id=0)]`, with one entry per GPU on multi-GPU systems.
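Listing devices shows what JAX detected, but it doesn't prove a kernel can run. A slightly stronger smoke test compiles and executes a tiny function with `jax.jit`, then reports the backend with `jax.default_backend()`. In this sketch, `smoke_test` is a hypothetical name. It returns a message instead of failing when JAX isn't installed.

```python
import importlib.util


def smoke_test() -> str:
    """Run a tiny jitted computation and report which backend executed it."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed in this environment"
    import jax
    import jax.numpy as jnp

    x = jnp.arange(4.0)
    y = jax.jit(lambda v: v * 2.0)(x)  # compiled for the default backend
    if y.tolist() != [0.0, 2.0, 4.0, 6.0]:
        raise RuntimeError(f"unexpected result: {y}")
    return f"jax {jax.__version__} on {jax.default_backend()}: {jax.devices()}"


if __name__ == "__main__":
    print(smoke_test())
```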

## Check Python and platform compatibility

JAX 0.10.0 requires Python 3.11 or later. It publishes wheels for CPython 3.11 through 3.14, including free-threaded builds (3.13t and 3.14t). `jaxlib` publishes wheels for Linux (x86_64 and aarch64), macOS (ARM64), and Windows (x86_64).

## Learn More

- [uv: A Complete Guide](https://pydevtools.com/handbook/explanation/uv-complete-guide.md) covers what uv does, how fast it is, the core workflows, and recent releases.
- [Official JAX installation guide](https://docs.jax.dev/en/latest/installation.html)
- [Why Installing GPU Python Packages Is So Complicated](https://pydevtools.com/handbook/explanation/installing-cuda-python-packages.md)
- [How to Install PyTorch with uv](https://pydevtools.com/handbook/how-to/how-to-install-pytorch-with-uv.md) for the index-routing approach that PyTorch requires
- [uv vs. pixi vs. conda for Scientific Python](https://pydevtools.com/handbook/explanation/uv-vs-pixi-vs-conda-for-scientific-python.md)
- [JAX GitHub repository](https://github.com/jax-ml/jax)
