diff --git a/.github/workflows/_test_requiring_torch_cuda.yml b/.github/workflows/_test_requiring_torch_cuda.yml index 9861e4fba161..79f5e05fa6d6 100644 --- a/.github/workflows/_test_requiring_torch_cuda.yml +++ b/.github/workflows/_test_requiring_torch_cuda.yml @@ -85,6 +85,9 @@ jobs: echo "Check if CUDA is available for PyTorch..." python -c "import torch; assert torch.cuda.is_available()" echo "CUDA is available for PyTorch." + echo "Check if CUDA is available for PyTorch/XLA..." + PJRT_DEVICE=CUDA python -c "import torch; import torch_xla; print(torch.tensor([1,2,3], device='xla')); assert torch_xla.runtime.device_type() == 'CUDA'" + echo "CUDA is available for PyTorch/XLA." - name: Checkout PyTorch Repo if: inputs.has_code_changes == 'true' uses: actions/checkout@v4 @@ -97,12 +100,6 @@ jobs: uses: actions/checkout@v4 with: path: pytorch/xla - - name: Extra CI deps - if: inputs.has_code_changes == 'true' && matrix.run_triton_tests - shell: bash - run: | - set -x - pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html - name: Install Triton if: inputs.has_code_changes == 'true' shell: bash diff --git a/build_util.py b/build_util.py index 487f5116323e..edb22f22ac23 100644 --- a/build_util.py +++ b/build_util.py @@ -1,8 +1,154 @@ import os -from typing import Iterable +from collections.abc import Iterable import subprocess import sys import shutil +from dataclasses import dataclass +import functools + +import platform + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + + +@functools.lru_cache +def get_pinned_packages(): + """Gets the versions of important pinned dependencies of torch_xla.""" + return PinnedPackages( + use_nightly=True, + date='20250424', + raw_libtpu_version='0.0.14', + raw_jax_version='0.6.1', + raw_jaxlib_version='0.6.1', + ) + + +@functools.lru_cache +def get_build_version(): + xla_git_sha, _torch_git_sha = get_git_head_sha(BASE_DIR) + version = os.getenv('TORCH_XLA_VERSION', '2.8.0') + if check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'): + try: + version += '+git' + xla_git_sha[:7] + except Exception: + pass + return version + + +@functools.lru_cache +def get_git_head_sha(base_dir): + xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], + cwd=base_dir).decode('ascii').strip() + if os.path.isdir(os.path.join(base_dir, '..', '.git')): + torch_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], + cwd=os.path.join( + base_dir, + '..')).decode('ascii').strip() + else: + torch_git_sha = '' + return xla_git_sha, torch_git_sha + + +@functools.lru_cache +def get_jax_install_requirements(): + """Get a list of JAX requirements for use in setup.py without extra package registries.""" + pinned_packages = get_pinned_packages() + if not pinned_packages.use_nightly: + # Stable versions of JAX can be directly installed from PyPI. + return [ + f'jaxlib=={pinned_packages.jaxlib_version}', + f'jax=={pinned_packages.jax_version}', + ] + + # Install nightly JAX libraries from the JAX package registries. + # TODO(https://github.com/pytorch/xla/issues/9064): This URL needs to be + # updated to use the new JAX package registry for any JAX builds after Apr 28, 2025. + jax = f'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/jax-{pinned_packages.jax_version}-py3-none-any.whl' + jaxlib = [] + for python_minor_version in [9, 10, 11]: + jaxlib.append( + f'jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' + ) + return [jax] + jaxlib + + +@functools.lru_cache +def get_jax_cuda_requirements(): + """Get a list of JAX CUDA requirements for use in setup.py without extra package registries.""" + pinned_packages = get_pinned_packages() + jax_requirements = get_jax_install_requirements() + + # Install nightly JAX CUDA libraries. + jax_cuda = [ + f'jax-cuda12-pjrt @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_pjrt-{pinned_packages.jax_version}-py3-none-manylinux2014_x86_64.whl' + ] + for python_minor_version in [9, 10, 11]: + jax_cuda.append( + f'jax-cuda12-plugin @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_plugin-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' + ) + + return jax_requirements + jax_cuda + + +@dataclass(eq=True, frozen=True) +class PinnedPackages: + use_nightly: bool + """Whether to use nightly or stable libtpu and JAX""" + + date: str + """The date of the libtpu and jax build""" + + raw_libtpu_version: str + """libtpu version string in [major].[minor].[patch] format.""" + + raw_jax_version: str + """jax version string in [major].[minor].[patch] format.""" + + raw_jaxlib_version: str + """jaxlib version string in [major].[minor].[patch] format.""" + + @property + def libtpu_version(self) -> str: + if self.use_nightly: + return f'{self.raw_libtpu_version}.dev{self.date}' + else: + return self.raw_libtpu_version + + @property + def jax_version(self) -> str: + if self.use_nightly: + return f'{self.raw_jax_version}.dev{self.date}' + else: + return self.raw_jax_version + + @property + def jaxlib_version(self) -> str: + if self.use_nightly: + return f'{self.raw_jaxlib_version}.dev{self.date}' + else: + return self.raw_jaxlib_version + + @property + def libtpu_storage_directory(self) -> str: + if self.use_nightly: + return 'libtpu-nightly-releases' + else: + return 'libtpu-lts-releases' + + @property + def libtpu_wheel_name(self) -> str: + if self.use_nightly: + return f'libtpu-{self.libtpu_version}+nightly' + else: + return f'libtpu-{self.libtpu_version}' + + @property + def libtpu_storage_path(self) -> str: + platform_machine = platform.machine() + # The suffix can be changed when the version is updated. Check + # https://storage.googleapis.com/libtpu-wheels/index.html for correct name. + suffix = f"py3-none-manylinux_2_31_{platform_machine}" + return f'https://storage.googleapis.com/{self.libtpu_storage_directory}/wheels/libtpu/{self.libtpu_wheel_name}-{suffix}.whl' def check_env_flag(name: str, default: str = '') -> bool: @@ -60,7 +206,7 @@ def bazel_build(bazel_target: str, ] # Remove duplicated flags because they confuse bazel - flags = set(bazel_options_from_env() + options) + flags = set(list(bazel_options_from_env()) + list(options)) bazel_argv.extend(flags) print(' '.join(bazel_argv), flush=True) diff --git a/infra/ansible/config/cuda_deps.yaml b/infra/ansible/config/cuda_deps.yaml index 3732bb0f93ec..9609d96e6209 100644 --- a/infra/ansible/config/cuda_deps.yaml +++ b/infra/ansible/config/cuda_deps.yaml @@ -1,22 +1,22 @@ # Versions of cuda dependencies for given cuda versions. # Note: wrap version in quotes to ensure they're treated as strings. cuda_deps: - # List all libcudnn8 versions with `apt list -a libcudnn8` + # Find package versions from https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ libcudnn: - "12.8": libcudnn9-cuda-12=9.1.1.17-1 - "12.6": libcudnn9-cuda-12=9.1.1.17-1 - "12.4": libcudnn9-cuda-12=9.1.1.17-1 - "12.3": libcudnn9-cuda-12=9.1.1.17-1 + "12.8": libcudnn9-cuda-12=9.8.0.87-1 + "12.6": libcudnn9-cuda-12=9.8.0.87-1 + "12.4": libcudnn9-cuda-12=9.8.0.87-1 + "12.3": libcudnn9-cuda-12=9.8.0.87-1 "12.1": libcudnn8=8.9.2.26-1+cuda12.1 "12.0": libcudnn8=8.8.0.121-1+cuda12.0 "11.8": libcudnn8=8.7.0.84-1+cuda11.8 "11.7": libcudnn8=8.5.0.96-1+cuda11.7 "11.2": libcudnn8=8.1.1.33-1+cuda11.2 libcudnn-dev: - "12.8": libcudnn9-dev-cuda-12=9.1.1.17-1 - "12.6": libcudnn9-dev-cuda-12=9.1.1.17-1 - "12.4": libcudnn9-dev-cuda-12=9.1.1.17-1 - "12.3": libcudnn9-dev-cuda-12=9.1.1.17-1 + "12.8": libcudnn9-dev-cuda-12=9.8.0.87-1 + "12.6": libcudnn9-dev-cuda-12=9.8.0.87-1 + "12.4": libcudnn9-dev-cuda-12=9.8.0.87-1 + "12.3": libcudnn9-dev-cuda-12=9.8.0.87-1 "12.1": libcudnn8-dev=8.9.2.26-1+cuda12.1 "12.0": libcudnn8-dev=8.8.0.121-1+cuda12.0 "11.8": libcudnn8-dev=8.7.0.84-1+cuda11.8 diff --git a/infra/ansible/roles/build_plugin/tasks/main.yaml b/infra/ansible/roles/build_plugin/tasks/main.yaml index 142d29c3718f..65282bbe1520 100644 --- a/infra/ansible/roles/build_plugin/tasks/main.yaml +++ b/infra/ansible/roles/build_plugin/tasks/main.yaml @@ -6,12 +6,27 @@ - name: Build PyTorch/XLA CUDA Plugin ansible.builtin.command: - cmd: pip wheel -w /dist plugins/cuda -v + cmd: pip wheel plugins/cuda -v chdir: "{{ (src_root, 'pytorch/xla') | path_join }}" environment: "{{ env_vars }}" when: accelerator == "cuda" -- name: Find CUDA plugin wheel pytorch/xla/dist +- name: Find the built CUDA plugin wheel + ansible.builtin.find: + paths: "{{ (src_root, 'pytorch/xla') | path_join }}" # Look in the dir where pip saved the wheel + patterns: "torch_xla_cuda_plugin-*.whl" + recurse: no + when: accelerator == "cuda" + register: built_plugin_wheel_info + +- name: Copy the CUDA plugin wheel to /dist + ansible.builtin.copy: + src: "{{ item.path }}" + dest: "/dist/{{ item.path | basename }}" # Ensure only the filename is used for dest + loop: "{{ built_plugin_wheel_info.files }}" + when: accelerator == "cuda" and built_plugin_wheel_info.files | length > 0 + +- name: Find CUDA plugin wheel in /dist ansible.builtin.find: path: "/dist" pattern: "torch_xla_cuda_plugin*.whl" diff --git a/plugins/cuda/pyproject.toml b/plugins/cuda/pyproject.toml index d44a2ea3bd53..a5e65c6bc69d 100644 --- a/plugins/cuda/pyproject.toml +++ b/plugins/cuda/pyproject.toml @@ -9,7 +9,7 @@ authors = [ ] description = "PyTorch/XLA CUDA Plugin" requires-python = ">=3.8" -dynamic = ["version"] +dynamic = ["version", "dependencies"] [tool.setuptools.package-data] torch_xla_cuda_plugin = ["lib/*.so"] diff --git a/plugins/cuda/setup.py b/plugins/cuda/setup.py index 2652880c6fd7..6e075e3bbf57 100644 --- a/plugins/cuda/setup.py +++ b/plugins/cuda/setup.py @@ -1,4 +1,3 @@ -import datetime import os import sys @@ -12,6 +11,6 @@ 'torch_xla_cuda_plugin/lib', ['--config=cuda']) setuptools.setup( - # TODO: Use a common version file - version=os.getenv('TORCH_XLA_VERSION', - f'2.8.0.dev{datetime.date.today().strftime("%Y%m%d")}')) + version=build_util.get_build_version(), + install_requires=build_util.get_jax_cuda_requirements(), +) diff --git a/setup.py b/setup.py index 3eae00e2796a..ec9953fa68af 100644 --- a/setup.py +++ b/setup.py @@ -56,41 +56,14 @@ import re import requests import shutil -import subprocess import sys import tempfile import zipfile import build_util -import platform - -platform_machine = platform.machine() - base_dir = os.path.dirname(os.path.abspath(__file__)) - -USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax - -_date = '20250424' - -_libtpu_version = '0.0.14' -_jax_version = '0.6.1' -_jaxlib_version = '0.6.1' - -if USE_NIGHTLY: - _libtpu_version += f".dev{_date}" - _jax_version += f'.dev{_date}' - _jaxlib_version += f'.dev{_date}' - _libtpu_wheel_name = f'libtpu-{_libtpu_version}.dev{_date}+nightly-py3-none-manylinux_2_31_{platform_machine}' - _libtpu_storage_directory = 'libtpu-nightly-releases' -else: - # The postfix can be changed when the version is updated. Check - # https://storage.googleapis.com/libtpu-wheels/index.html for correct - # versioning. - _libtpu_wheel_name = f'libtpu-{_libtpu_version}-py3-none-manylinux_2_31_{platform_machine}' - _libtpu_storage_directory = 'libtpu-lts-releases' - -_libtpu_storage_path = f'https://storage.googleapis.com/{_libtpu_storage_directory}/wheels/libtpu/{_libtpu_wheel_name}.whl' +pinned_packages = build_util.get_pinned_packages() def _get_build_mode(): @@ -99,29 +72,6 @@ def _get_build_mode(): return sys.argv[i] -def get_git_head_sha(base_dir): - xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], - cwd=base_dir).decode('ascii').strip() - if os.path.isdir(os.path.join(base_dir, '..', '.git')): - torch_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], - cwd=os.path.join( - base_dir, - '..')).decode('ascii').strip() - else: - torch_git_sha = '' - return xla_git_sha, torch_git_sha - - -def get_build_version(xla_git_sha): - version = os.getenv('TORCH_XLA_VERSION', '2.8.0') - if build_util.check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'): - try: - version += '+git' + xla_git_sha[:7] - except Exception: - pass - return version - - def create_version_files(base_dir, version, xla_git_sha, torch_git_sha): print('Building torch_xla version: {}'.format(version)) print('XLA Commit ID: {}'.format(xla_git_sha)) @@ -160,7 +110,7 @@ def maybe_bundle_libtpu(base_dir): print('No installed libtpu found. Downloading...') with tempfile.NamedTemporaryFile('wb') as whl: - resp = requests.get(_libtpu_storage_path) + resp = requests.get(pinned_packages.libtpu_storage_path) resp.raise_for_status() whl.write(resp.content) @@ -203,8 +153,8 @@ def run(self): distutils.command.clean.clean.run(self) -xla_git_sha, torch_git_sha = get_git_head_sha(base_dir) -version = get_build_version(xla_git_sha) +xla_git_sha, torch_git_sha = build_util.get_git_head_sha(base_dir) +version = build_util.get_build_version() build_mode = _get_build_mode() if build_mode not in ['clean']: @@ -353,24 +303,6 @@ def link_packages(self): f.write(path + "\n") -def _get_jax_install_requirements(): - if not USE_NIGHTLY: - # Stable versions of JAX can be directly installed from PyPI. - return [ - f'jaxlib=={_jaxlib_version}', - f'jax=={_jax_version}', - ] - - # Install nightly JAX libraries from the JAX package registries. - jax = f'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/jax-{_jax_version}-py3-none-any.whl' - jaxlib = [] - for python_minor_version in [9, 10, 11]: - jaxlib.append( - f'jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-{_jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' - ) - return [jax] + jaxlib - - setup( name=os.environ.get('TORCH_XLA_PACKAGE_NAME', 'torch_xla'), version=version, @@ -411,7 +343,7 @@ def _get_jax_install_requirements(): # to Python 3.10 'importlib_metadata>=4.6;python_version<"3.10"', # Some torch operations are lowered to HLO via JAX. - *_get_jax_install_requirements(), + *build_util.get_jax_install_requirements(), ], package_data={ 'torch_xla': ['lib/*.so*',], @@ -430,13 +362,16 @@ def _get_jax_install_requirements(): # On Cloud TPU VM install with: # pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html 'tpu': [ - f'libtpu=={_libtpu_version}', + f'libtpu=={pinned_packages.libtpu_version}', 'tpu-info', ], # As of https://github.com/pytorch/xla/pull/8895, jax is always a dependency of torch_xla. # However, this no-op extras_require entrypoint is left here for backwards compatibility. # pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - 'pallas': [f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}'], + 'pallas': [ + f'jaxlib=={pinned_packages.jaxlib_version}', + f'jax=={pinned_packages.jax_version}' + ], }, cmdclass={ 'build_ext': BuildBazelExtension,