
Set up a GPU data science project with pixi

This tutorial builds a PyTorch image classification project managed by pixi. It loads a pretrained ResNet model, classifies a sample image, and prints the top-5 predictions. The project uses pixi’s multi-environment feature to support both GPU (CUDA) and CPU-only machines from a single pixi.toml.

Prerequisites

Why conda-forge for PyTorch

PyTorch on conda-forge shares the CUDA runtime with other packages. When you install PyTorch and cuDNN from conda-forge, they use the same cuda-toolkit package rather than each bundling its own copy. This produces smaller environments and avoids version conflicts between CUDA-dependent libraries.

Compare this to PyPI, where PyTorch wheels bundle their own CUDA libraries, producing larger downloads and potential conflicts with other GPU packages. See How to Install PyTorch with uv for the PyPI approach.

Create the project

pixi init image_classifier
cd image_classifier

Add base dependencies

Start with the packages both GPU and CPU environments need:

pixi add python pytorch torchvision

This pulls PyTorch from conda-forge. Because the project does not yet declare any CUDA requirement, pixi resolves CPU-only builds at this stage. The GPU variant gets its own environment later in this tutorial.

Add a PyPI-only package

Some packages aren’t on conda-forge. Add them from PyPI with the --pypi flag:

pixi add --pypi Pillow

Note

In this example, Pillow is available on both conda-forge and PyPI, but the --pypi flag demonstrates how to mix sources. Use --pypi for packages that genuinely exist only on PyPI.

The pixi.toml now has both conda and PyPI dependencies. The exact version bounds depend on the latest releases available when you run pixi add:

pixi.toml
[dependencies]
python = ">=3.13.3,<4"
pytorch = ">=2.6.0,<3"
torchvision = ">=0.21.0,<1"

[pypi-dependencies]
pillow = "*"

Define CPU and GPU environments

Pixi’s multi-environment feature lets you define variants of your project. Add these sections to pixi.toml:

pixi.toml
[feature.cuda]
platforms = ["linux-64"]
system-requirements = {cuda = "12"}

[feature.cuda.dependencies]
cuda-version = "12.*"

[environments]
gpu = ["cuda"]

This creates a gpu environment that includes the cuda feature. The system-requirements entry declares that this environment needs CUDA 12 or newer drivers, so pixi only allows it on such machines and lets the solver consider CUDA-enabled builds. The cuda-version pin then constrains PyTorch and its dependencies to CUDA 12 builds. The platforms restriction limits the GPU environment to 64-bit Linux (linux-64), where CUDA on conda-forge is supported.

The default environment (used by plain pixi run) gets CPU-only builds automatically since no CUDA dependency is present. Both environments inherit the base [dependencies] (Python, PyTorch, torchvision).

Note

CUDA-enabled builds on conda-forge sometimes lag behind the latest CPU-only builds. If pixi install picks a CPU build for the gpu environment, check whether a CUDA build exists for your Python version and PyTorch version on conda-forge.

Download a sample image

Create a data directory and download a test image:

mkdir data
curl -L -o data/cat.jpg "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4d/Cat_November_2010-1a.jpg/1200px-Cat_November_2010-1a.jpg"

Write the classification script

Create classify.py:

classify.py
import json
import sys
import urllib.request

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"


def load_labels():
    """Download human-readable ImageNet class labels (a JSON list of 1,000 names)."""
    with urllib.request.urlopen(LABELS_URL, timeout=30) as response:
        return json.loads(response.read().decode("utf-8"))


def classify_image(image_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load pretrained ResNet
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    model.to(device)
    model.eval()

    # Preprocess image
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    image = Image.open(image_path).convert("RGB")
    input_tensor = preprocess(image).unsqueeze(0).to(device)

    # Run inference
    with torch.no_grad():
        output = model(input_tensor)

    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top5_prob, top5_idx = torch.topk(probabilities, 5)

    labels = load_labels()

    print(f"\nTop 5 predictions for {image_path}:")
    for i in range(5):
        idx = top5_idx[i].item()
        prob = top5_prob[i].item()
        print(f"  {labels[idx]:30s} {prob:.1%}")


if __name__ == "__main__":
    image_path = sys.argv[1] if len(sys.argv) > 1 else "data/cat.jpg"
    classify_image(image_path)
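The preprocessing pipeline is plain arithmetic. Resize(256) with a single integer scales the shorter side to 256 pixels and keeps the aspect ratio, CenterCrop(224) cuts a 224×224 square from the middle, ToTensor scales 8-bit values to [0, 1], and Normalize computes (x - mean) / std per channel. The pure-Python sketch below reproduces those steps for image sizes and a single pixel. The helper names are hypothetical, and the rounding is an assumption modeled on torchvision's behavior, so treat results as approximate.

```python
MEAN = (0.485, 0.456, 0.406)  # ImageNet channel means used in classify.py
STD = (0.229, 0.224, 0.225)   # ImageNet channel standard deviations used in classify.py


def resized_size(width: int, height: int, short_side: int = 256) -> tuple[int, int]:
    """(width, height) after Resize(short_side): shorter edge becomes short_side."""
    if width <= height:
        return short_side, int(short_side * height / width)
    return int(short_side * width / height), short_side


def center_crop_box(width: int, height: int, crop: int = 224) -> tuple[int, int, int, int]:
    """(left, top, right, bottom) of a centered crop x crop square."""
    left = int(round((width - crop) / 2))
    top = int(round((height - crop) / 2))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb: tuple[int, int, int]) -> list[float]:
    """ToTensor (divide by 255), then Normalize ((x - mean) / std) for each channel."""
    return [(value / 255 - m) / s for value, m, s in zip(rgb, MEAN, STD)]


if __name__ == "__main__":
    w, h = resized_size(1200, 900)  # a hypothetical 1200x900 input image
    print((w, h), center_crop_box(w, h))
    print([round(v, 3) for v in normalize_pixel((255, 128, 0))])
```

After normalization, channel values are roughly centered on zero, which matches the statistics the pretrained ResNet saw during training.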

Run in the default (CPU) environment

pixi run python classify.py data/cat.jpg

Expected output (the pretrained weights download on the first run, and the script fetches the label list from GitHub on every run):

Using device: cpu

Top 5 predictions for data/cat.jpg:
  tabby cat                      47.2%
  tiger cat                      30.1%
  Egyptian cat                   8.5%
  lynx                           1.2%
  Persian cat                    0.4%

The exact percentages vary slightly across PyTorch versions.
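The percentages are softmax probabilities over the model's 1,000 raw output scores (logits), and the script keeps the five largest with torch.topk. A pure-Python sketch of the same two steps, with made-up logits for illustration:

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs: list[float], k: int) -> list[tuple[int, float]]:
    """(index, probability) pairs for the k largest entries, highest first, like torch.topk."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.1, 3.0, -1.0, 0.5]  # made-up scores, not real ResNet output
    for idx, prob in top_k(softmax(logits), 3):
        print(f"class {idx}: {prob:.1%}")
```

Because softmax spreads probability over all 1,000 classes, the top five rarely sum to 100%.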

Run in the GPU environment

On a machine with an NVIDIA GPU and CUDA drivers:

pixi run --environment gpu python classify.py data/cat.jpg

The output is identical except that the first line reads Using device: cuda. Inference is faster, though for a single image the difference is marginal. GPU acceleration matters when classifying batches of images or training models.
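To benefit from the GPU with many images, you would stack preprocessed tensors into one batch and run the model once per batch instead of once per image. A minimal sketch of the batching step follows, assuming you already have a list of image paths. The helper batched_paths and the file names are hypothetical, and the comments show where classify.py's preprocess, torch.stack, and the model call would fit.

```python
from collections.abc import Iterator, Sequence


def batched_paths(paths: Sequence[str], batch_size: int) -> Iterator[list[str]]:
    """Yield consecutive groups of at most batch_size paths."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(paths), batch_size):
        yield list(paths[start:start + batch_size])


# In classify.py-style code, each group would become one forward pass:
#   batch = torch.stack([preprocess(Image.open(p).convert("RGB")) for p in group]).to(device)
#   with torch.no_grad():
#       probabilities = torch.nn.functional.softmax(model(batch), dim=1)
# dim=1 because each row of the output holds one image's 1,000 class scores.

if __name__ == "__main__":
    images = [f"data/img_{i}.jpg" for i in range(10)]  # hypothetical file names
    for group in batched_paths(images, batch_size=4):
        print(len(group), group[0])
```

Larger batches keep the GPU busier, up to the limit of its memory.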

Add a task for classification

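Add a task to pixi.toml so the command has a short name:

pixi.toml
[tasks]
classify = "python classify.py data/cat.jpg"

Run it in the default environment, or pass --environment gpu to run it in the GPU environment:

pixi run classify
pixi run --environment gpu classify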

Add Jupyter for exploration

pixi add jupyter

Launch Jupyter:

pixi run jupyter lab

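Within a notebook, you can import torch and experiment with different models, images, or preprocessing pipelines. Because pixi run launches Jupyter inside the project's default environment, the notebook kernel uses that environment's packages automatically. Run pixi run --environment gpu jupyter lab to work with the CUDA build instead.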

Final project structure

    • pixi.toml
    • pixi.lock
    • classify.py
    • .gitignore
    • data/
      • cat.jpg

Next steps
