Skip to content

Make GPU CUDA plugin require JAX #8919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions .github/workflows/_test_requiring_torch_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ jobs:
echo "Check if CUDA is available for PyTorch..."
python -c "import torch; assert torch.cuda.is_available()"
echo "CUDA is available for PyTorch."
echo "Check if CUDA is available for PyTorch/XLA..."
PJRT_DEVICE=CUDA python -c "import torch; import torch_xla; print(torch.tensor([1,2,3], device='xla')); assert torch_xla.runtime.device_type() == 'CUDA'"
echo "CUDA is available for PyTorch/XLA."
- name: Checkout PyTorch Repo
if: inputs.has_code_changes == 'true'
uses: actions/checkout@v4
Expand All @@ -97,12 +100,6 @@ jobs:
uses: actions/checkout@v4
with:
path: pytorch/xla
- name: Extra CI deps
if: inputs.has_code_changes == 'true' && matrix.run_triton_tests
shell: bash
run: |
set -x
pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
- name: Install Triton
if: inputs.has_code_changes == 'true'
shell: bash
Expand Down
150 changes: 148 additions & 2 deletions build_util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,154 @@
import os
from typing import Iterable
from collections.abc import Iterable
import subprocess
import sys
import shutil
from dataclasses import dataclass
import functools

import platform

BASE_DIR = os.path.dirname(os.path.abspath(__file__))


@functools.lru_cache
def get_pinned_packages():
  """Return the pinned versions of torch_xla's key dependencies (libtpu, JAX)."""
  # Single source of truth for the nightly date and raw version strings;
  # cached so every caller sees the same frozen record.
  pinned = PinnedPackages(
      use_nightly=True,
      date='20250424',
      raw_libtpu_version='0.0.14',
      raw_jax_version='0.6.1',
      raw_jaxlib_version='0.6.1',
  )
  return pinned


@functools.lru_cache
def get_build_version():
  """Compute the torch_xla wheel version string.

  Starts from the TORCH_XLA_VERSION env var (default '2.8.0') and, unless
  GIT_VERSIONED_XLA_BUILD is disabled, appends '+git<short-sha>' taken from
  the current XLA checkout's HEAD.
  """
  sha, _ = get_git_head_sha(BASE_DIR)
  base_version = os.getenv('TORCH_XLA_VERSION', '2.8.0')
  if not check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'):
    return base_version
  try:
    return base_version + '+git' + sha[:7]
  except Exception:
    # Best effort: fall back to the unsuffixed version if the sha is unusable.
    return base_version


@functools.lru_cache
def get_git_head_sha(base_dir):
  """Return (xla_sha, torch_sha): HEAD commit hashes of the XLA repo at
  base_dir and, when a .git directory exists one level up, of the parent
  PyTorch checkout. torch_sha is '' when the parent is not a git repo.
  """

  def _head_sha(path):
    # `git rev-parse HEAD` prints the full commit hash of the checkout at path.
    return subprocess.check_output(['git', 'rev-parse', 'HEAD'],
                                   cwd=path).decode('ascii').strip()

  xla_git_sha = _head_sha(base_dir)
  parent_dir = os.path.join(base_dir, '..')
  if os.path.isdir(os.path.join(parent_dir, '.git')):
    torch_git_sha = _head_sha(parent_dir)
  else:
    torch_git_sha = ''
  return xla_git_sha, torch_git_sha


@functools.lru_cache
def get_jax_install_requirements():
  """Get a list of JAX requirements for use in setup.py without extra package registries."""
  pinned = get_pinned_packages()

  if not pinned.use_nightly:
    # Stable versions of JAX can be directly installed from PyPI.
    return [
        f'jaxlib=={pinned.jaxlib_version}',
        f'jax=={pinned.jax_version}',
    ]

  # Install nightly JAX libraries from the JAX package registries, using
  # PEP 508 direct-URL requirements with per-Python-version markers.
  # TODO(https://github.com/pytorch/xla/issues/9064): This URL needs to be
  # updated to use the new JAX package registry for any JAX builds after Apr 28, 2025.
  jax_wheel = (
      'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/'
      f'jax-{pinned.jax_version}-py3-none-any.whl')
  jaxlib_wheels = [
      ('jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/'
       f'jaxlib-{pinned.jaxlib_version}-cp3{minor}-cp3{minor}-manylinux2014_x86_64.whl'
       f' ; python_version == "3.{minor}"') for minor in (9, 10, 11)
  ]
  return [jax_wheel] + jaxlib_wheels


@functools.lru_cache
def get_jax_cuda_requirements():
  """Get a list of JAX CUDA requirements for use in setup.py without extra package registries."""
  pinned = get_pinned_packages()

  # Install nightly JAX CUDA libraries: the PJRT runtime wheel plus one
  # CUDA-plugin wheel per supported CPython minor version.
  pjrt_wheel = (
      'jax-cuda12-pjrt @ https://storage.googleapis.com/jax-releases/nightly/wheels/'
      f'jax_cuda12_pjrt-{pinned.jax_version}-py3-none-manylinux2014_x86_64.whl')
  plugin_wheels = [
      ('jax-cuda12-plugin @ https://storage.googleapis.com/jax-releases/nightly/wheels/'
       f'jax_cuda12_plugin-{pinned.jaxlib_version}-cp3{minor}-cp3{minor}-manylinux2014_x86_64.whl'
       f' ; python_version == "3.{minor}"') for minor in (9, 10, 11)
  ]
  return get_jax_install_requirements() + [pjrt_wheel] + plugin_wheels


@dataclass(eq=True, frozen=True)
class PinnedPackages:
  """Immutable record of pinned libtpu/JAX versions and where to fetch them."""

  # Whether to use nightly or stable libtpu and JAX.
  use_nightly: bool
  # The date of the libtpu and jax build.
  date: str
  # libtpu version string in [major].[minor].[patch] format.
  raw_libtpu_version: str
  # jax version string in [major].[minor].[patch] format.
  raw_jax_version: str
  # jaxlib version string in [major].[minor].[patch] format.
  raw_jaxlib_version: str

  @property
  def libtpu_version(self) -> str:
    # Nightly builds carry a '.dev<date>' suffix.
    return (f'{self.raw_libtpu_version}.dev{self.date}'
            if self.use_nightly else self.raw_libtpu_version)

  @property
  def jax_version(self) -> str:
    return (f'{self.raw_jax_version}.dev{self.date}'
            if self.use_nightly else self.raw_jax_version)

  @property
  def jaxlib_version(self) -> str:
    return (f'{self.raw_jaxlib_version}.dev{self.date}'
            if self.use_nightly else self.raw_jaxlib_version)

  @property
  def libtpu_storage_directory(self) -> str:
    # Nightly and LTS wheels live in separate GCS buckets.
    return ('libtpu-nightly-releases'
            if self.use_nightly else 'libtpu-lts-releases')

  @property
  def libtpu_wheel_name(self) -> str:
    local_tag = '+nightly' if self.use_nightly else ''
    return f'libtpu-{self.libtpu_version}{local_tag}'

  @property
  def libtpu_storage_path(self) -> str:
    machine = platform.machine()
    # The suffix can be changed when the version is updated. Check
    # https://storage.googleapis.com/libtpu-wheels/index.html for correct name.
    suffix = f"py3-none-manylinux_2_31_{machine}"
    return (f'https://storage.googleapis.com/{self.libtpu_storage_directory}'
            f'/wheels/libtpu/{self.libtpu_wheel_name}-{suffix}.whl')


def check_env_flag(name: str, default: str = '') -> bool:
Expand Down Expand Up @@ -60,7 +206,7 @@ def bazel_build(bazel_target: str,
]

# Remove duplicated flags because they confuse bazel
flags = set(bazel_options_from_env() + options)
flags = set(list(bazel_options_from_env()) + list(options))
bazel_argv.extend(flags)

print(' '.join(bazel_argv), flush=True)
Expand Down
18 changes: 9 additions & 9 deletions infra/ansible/config/cuda_deps.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
# Versions of cuda dependencies for given cuda versions.
# Note: wrap version in quotes to ensure they're treated as strings.
cuda_deps:
# List all libcudnn8 versions with `apt list -a libcudnn8`
# Find package versions from https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/
libcudnn:
"12.8": libcudnn9-cuda-12=9.1.1.17-1
"12.6": libcudnn9-cuda-12=9.1.1.17-1
"12.4": libcudnn9-cuda-12=9.1.1.17-1
"12.3": libcudnn9-cuda-12=9.1.1.17-1
"12.8": libcudnn9-cuda-12=9.8.0.87-1
"12.6": libcudnn9-cuda-12=9.8.0.87-1
"12.4": libcudnn9-cuda-12=9.8.0.87-1
"12.3": libcudnn9-cuda-12=9.8.0.87-1
"12.1": libcudnn8=8.9.2.26-1+cuda12.1
"12.0": libcudnn8=8.8.0.121-1+cuda12.0
"11.8": libcudnn8=8.7.0.84-1+cuda11.8
"11.7": libcudnn8=8.5.0.96-1+cuda11.7
"11.2": libcudnn8=8.1.1.33-1+cuda11.2
libcudnn-dev:
"12.8": libcudnn9-dev-cuda-12=9.1.1.17-1
"12.6": libcudnn9-dev-cuda-12=9.1.1.17-1
"12.4": libcudnn9-dev-cuda-12=9.1.1.17-1
"12.3": libcudnn9-dev-cuda-12=9.1.1.17-1
"12.8": libcudnn9-dev-cuda-12=9.8.0.87-1
"12.6": libcudnn9-dev-cuda-12=9.8.0.87-1
"12.4": libcudnn9-dev-cuda-12=9.8.0.87-1
"12.3": libcudnn9-dev-cuda-12=9.8.0.87-1
"12.1": libcudnn8-dev=8.9.2.26-1+cuda12.1
"12.0": libcudnn8-dev=8.8.0.121-1+cuda12.0
"11.8": libcudnn8-dev=8.7.0.84-1+cuda11.8
Expand Down
19 changes: 17 additions & 2 deletions infra/ansible/roles/build_plugin/tasks/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,27 @@

- name: Build PyTorch/XLA CUDA Plugin
ansible.builtin.command:
cmd: pip wheel -w /dist plugins/cuda -v
cmd: pip wheel plugins/cuda -v
chdir: "{{ (src_root, 'pytorch/xla') | path_join }}"
environment: "{{ env_vars }}"
when: accelerator == "cuda"

- name: Find CUDA plugin wheel pytorch/xla/dist
- name: Find the built CUDA plugin wheel
ansible.builtin.find:
paths: "{{ (src_root, 'pytorch/xla') | path_join }}" # Look in the dir where pip saved the wheel
patterns: "torch_xla_cuda_plugin-*.whl"
recurse: no
when: accelerator == "cuda"
register: built_plugin_wheel_info

- name: Copy the CUDA plugin wheel to /dist
ansible.builtin.copy:
src: "{{ item.path }}"
dest: "/dist/{{ item.path | basename }}" # Ensure only the filename is used for dest
loop: "{{ built_plugin_wheel_info.files }}"
when: accelerator == "cuda" and built_plugin_wheel_info.files | length > 0

- name: Find CUDA plugin wheel in /dist
ansible.builtin.find:
path: "/dist"
pattern: "torch_xla_cuda_plugin*.whl"
Expand Down
2 changes: 1 addition & 1 deletion plugins/cuda/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [
]
description = "PyTorch/XLA CUDA Plugin"
requires-python = ">=3.8"
dynamic = ["version"]
dynamic = ["version", "dependencies"]

[tool.setuptools.package-data]
torch_xla_cuda_plugin = ["lib/*.so"]
Expand Down
7 changes: 3 additions & 4 deletions plugins/cuda/setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import datetime
import os
import sys

Expand All @@ -12,6 +11,6 @@
'torch_xla_cuda_plugin/lib', ['--config=cuda'])

setuptools.setup(
# TODO: Use a common version file
version=os.getenv('TORCH_XLA_VERSION',
f'2.8.0.dev{datetime.date.today().strftime("%Y%m%d")}'))
version=build_util.get_build_version(),
install_requires=build_util.get_jax_cuda_requirements(),
)
Loading
Loading