diff --git a/.ci/docker/README.md b/.ci/docker/README.md index 6e0dcfd6b25d7..56d6ff1f2fe27 100644 --- a/.ci/docker/README.md +++ b/.ci/docker/README.md @@ -25,7 +25,6 @@ See `build.sh` for valid build environments (it's the giant switch). * `conda` - Dockerfile and build.sh to build Docker images used in nightly conda builds * `manywheel` - Dockerfile and build.sh to build Docker images used in nightly manywheel builds -* `libtorch` - Dockerfile and build.sh to build Docker images used in nightly libtorch builds ## Usage diff --git a/.ci/docker/libtorch/Dockerfile b/.ci/docker/libtorch/Dockerfile deleted file mode 100644 index 3f9d156965129..0000000000000 --- a/.ci/docker/libtorch/Dockerfile +++ /dev/null @@ -1,122 +0,0 @@ -ARG BASE_TARGET=base -ARG GPU_IMAGE=ubuntu:20.04 -FROM ${GPU_IMAGE} as base - -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt-get clean && apt-get update -RUN apt-get install -y curl locales g++ git-all autoconf automake make cmake wget unzip sudo -# Just add everything as a safe.directory for git since these will be used in multiple places with git -RUN git config --global --add safe.directory '*' - -RUN locale-gen en_US.UTF-8 - -ENV LC_ALL en_US.UTF-8 -ENV LANG en_US.UTF-8 -ENV LANGUAGE en_US.UTF-8 - -# Install openssl -FROM base as openssl -ADD ./common/install_openssl.sh install_openssl.sh -RUN bash ./install_openssl.sh && rm install_openssl.sh - -# Install python -FROM base as python -ADD common/install_cpython.sh install_cpython.sh -RUN apt-get update -y && \ - apt-get install build-essential gdb lcov libbz2-dev libffi-dev \ - libgdbm-dev liblzma-dev libncurses5-dev libreadline6-dev \ - libsqlite3-dev libssl-dev lzma lzma-dev tk-dev uuid-dev zlib1g-dev -y && \ - bash ./install_cpython.sh && \ - rm install_cpython.sh && \ - apt-get clean - -FROM base as conda -ADD ./common/install_conda_docker.sh install_conda.sh -RUN bash ./install_conda.sh && rm install_conda.sh - -FROM base as cpu -# Install Anaconda -COPY --from=conda /opt/conda /opt/conda -# Install python -COPY --from=python /opt/python /opt/python -COPY --from=python /opt/_internal /opt/_internal -ENV PATH=/opt/conda/bin:/usr/local/cuda/bin:$PATH -# Install MKL -ADD ./common/install_mkl.sh install_mkl.sh -RUN bash ./install_mkl.sh && rm install_mkl.sh - -FROM cpu as cuda -ADD ./common/install_cuda.sh install_cuda.sh -ADD ./common/install_magma.sh install_magma.sh -COPY ./common/install_nccl.sh install_nccl.sh -COPY ./ci_commit_pins/nccl* /ci_commit_pins/ -COPY ./common/install_cusparselt.sh install_cusparselt.sh -ENV CUDA_HOME /usr/local/cuda - -FROM cuda as cuda12.6 -RUN bash ./install_cuda.sh 12.6 -RUN bash ./install_magma.sh 12.6 -RUN ln -sf /usr/local/cuda-12.6 /usr/local/cuda - -FROM cuda as cuda12.8 -RUN bash ./install_cuda.sh 12.8 -RUN bash ./install_magma.sh 12.8 -RUN ln -sf /usr/local/cuda-12.8 /usr/local/cuda - -FROM cuda as cuda12.9 -RUN bash ./install_cuda.sh 12.9 -RUN bash ./install_magma.sh 12.9 -RUN ln -sf /usr/local/cuda-12.9 /usr/local/cuda - -FROM cuda as cuda13.0 -RUN bash ./install_cuda.sh 13.0 -RUN bash ./install_magma.sh 13.0 -RUN ln -sf /usr/local/cuda-13.0 /usr/local/cuda - -FROM cuda as cuda13.2 -RUN bash ./install_cuda.sh 13.2 -RUN bash ./install_magma.sh 13.2 -RUN ln -sf /usr/local/cuda-13.2 /usr/local/cuda - -# Install libibverbs for libtorch and copy to CUDA directory -RUN apt-get update -y && \ - apt-get install -y libibverbs-dev librdmacm-dev && \ - cp /usr/lib/x86_64-linux-gnu/libmlx5.so* /usr/local/cuda/lib64/ && \ - cp /usr/lib/x86_64-linux-gnu/librdmacm.so* /usr/local/cuda/lib64/ && \ - cp /usr/lib/x86_64-linux-gnu/libibverbs.so* /usr/local/cuda/lib64/ && \ - cp /usr/lib/x86_64-linux-gnu/libnl* /usr/local/cuda/lib64/ - -FROM cpu as rocm -ARG ROCM_VERSION -ARG PYTORCH_ROCM_ARCH -ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} -ENV MKLROOT /opt/intel -# Adding ROCM_PATH env var so that LoadHip.cmake (even with logic updated for ROCm6.0) -# find HIP works for ROCm5.7. Not needed for ROCm6.0 and above. -# Remove below when ROCm5.7 is not in support matrix anymore. -ENV ROCM_PATH /opt/rocm -# No need to install ROCm as base docker image should have full ROCm install -#ADD ./common/install_rocm.sh install_rocm.sh -ADD ./common/install_rocm_drm.sh install_rocm_drm.sh -ADD ./common/install_rocm_magma.sh install_rocm_magma.sh -# gfortran and python needed for building magma from source for ROCm -RUN apt-get update -y && \ - apt-get install gfortran -y && \ - apt-get install python3 python-is-python3 -y && \ - apt-get clean - -RUN bash ./install_rocm_drm.sh /opt/amdgpu && rm install_rocm_drm.sh -RUN bash ./install_rocm_magma.sh ${ROCM_VERSION} && rm install_rocm_magma.sh - -FROM ${BASE_TARGET} as final -COPY --from=openssl /opt/openssl /opt/openssl -# Install patchelf -ADD ./common/install_patchelf.sh install_patchelf.sh -RUN bash ./install_patchelf.sh && rm install_patchelf.sh -# Install Anaconda -COPY --from=conda /opt/conda /opt/conda -# Install python -COPY --from=python /opt/python /opt/python -COPY --from=python /opt/_internal /opt/_internal -ENV PATH=/opt/conda/bin:/usr/local/cuda/bin:$PATH diff --git a/.ci/docker/libtorch/build.sh b/.ci/docker/libtorch/build.sh deleted file mode 100755 index 5bfe70f34347e..0000000000000 --- a/.ci/docker/libtorch/build.sh +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env bash -# Script used only in CD pipeline - -set -eoux pipefail - -image="$1" -shift - -if [ -z "${image}" ]; then - echo "Usage: $0 IMAGENAME:ARCHTAG" - exit 1 -fi - -TOPDIR=$(git rev-parse --show-toplevel) - -DOCKER=${DOCKER:-docker} - -# Go from imagename:tag to tag -DOCKER_TAG_PREFIX=$(echo "${image}" | awk -F':' '{print $2}') - -GPU_ARCH_VERSION="" -if [[ "${DOCKER_TAG_PREFIX}" == cuda* ]]; then - # extract cuda version from image name. e.g. manylinux2_28-builder:cuda12.8 returns 12.8 - GPU_ARCH_VERSION=$(echo "${DOCKER_TAG_PREFIX}" | awk -F'cuda' '{print $2}') -elif [[ "${DOCKER_TAG_PREFIX}" == rocm* ]]; then - # extract rocm version from image name. e.g. manylinux2_28-builder:rocm6.2.4 returns 6.2.4 - GPU_ARCH_VERSION=$(echo "${DOCKER_TAG_PREFIX}" | awk -F'rocm' '{print $2}') -fi - -case ${DOCKER_TAG_PREFIX} in - cpu) - BASE_TARGET=cpu - GPU_IMAGE=ubuntu:20.04 - DOCKER_GPU_BUILD_ARG="" - ;; - cuda*) - BASE_TARGET=cuda${GPU_ARCH_VERSION} - GPU_IMAGE=ubuntu:20.04 - DOCKER_GPU_BUILD_ARG="" - ;; - rocm*) - # we want the patch version of 7.1 instead - if [[ "$GPU_ARCH_VERSION" == *"7.1"* ]]; then - GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.1" - fi - # we want the patch version of 7.0 instead - if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then - GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" - fi - # we want the patch version of 6.4 instead - if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then - GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4" - fi - BASE_TARGET=rocm - GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete - PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151" - DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}" - ;; - *) - echo "ERROR: Unrecognized DOCKER_TAG_PREFIX: ${DOCKER_TAG_PREFIX}" - exit 1 - ;; -esac - -tmp_tag=$(basename "$(mktemp -u)" | tr '[:upper:]' '[:lower:]') - -DOCKER_BUILDKIT=1 ${DOCKER} build \ - --target final \ - ${DOCKER_GPU_BUILD_ARG} \ - --build-arg "GPU_IMAGE=${GPU_IMAGE}" \ - --build-arg "BASE_TARGET=${BASE_TARGET}" \ - -t "${tmp_tag}" \ - $@ \ - -f "${TOPDIR}/.ci/docker/libtorch/Dockerfile" \ - "${TOPDIR}/.ci/docker/" diff --git a/.ci/docker/manywheel/Dockerfile_s390x b/.ci/docker/manywheel/Dockerfile_s390x index 1cf83acb1c736..1367b004ee8a3 100644 --- a/.ci/docker/manywheel/Dockerfile_s390x +++ b/.ci/docker/manywheel/Dockerfile_s390x @@ -84,7 +84,7 @@ RUN cp $(which patchelf) /patchelf FROM patchelf as python # build python -COPY manywheel/build_scripts /build_scripts +COPY manywheel/s390_scripts /build_scripts ADD ./common/install_cpython.sh /build_scripts/install_cpython.sh ENV SSL_CERT_FILE= RUN bash build_scripts/build.sh && rm -r build_scripts diff --git a/.ci/docker/manywheel/build_scripts/manylinux1-check.py b/.ci/docker/manywheel/build_scripts/manylinux1-check.py deleted file mode 100644 index f6b9b9fc2393e..0000000000000 --- a/.ci/docker/manywheel/build_scripts/manylinux1-check.py +++ /dev/null @@ -1,63 +0,0 @@ -# Logic copied from PEP 513 - - -def is_manylinux1_compatible(): - # Only Linux, and only x86-64 / i686 - from distutils.util import get_platform - - if get_platform() not in ["linux-x86_64", "linux-i686", "linux-s390x"]: - return False - - # Check for presence of _manylinux module - try: - import _manylinux - - return bool(_manylinux.manylinux1_compatible) - except (ImportError, AttributeError): - # Fall through to heuristic check below - pass - - # Check glibc version. CentOS 5 uses glibc 2.5. - return have_compatible_glibc(2, 5) - - -def have_compatible_glibc(major, minimum_minor): - import ctypes - - process_namespace = ctypes.CDLL(None) - try: - gnu_get_libc_version = process_namespace.gnu_get_libc_version - except AttributeError: - # Symbol doesn't exist -> therefore, we are not linked to - # glibc. - return False - - # Call gnu_get_libc_version, which returns a string like "2.5". - gnu_get_libc_version.restype = ctypes.c_char_p - version_str = gnu_get_libc_version() - # py2 / py3 compatibility: - if not isinstance(version_str, str): - version_str = version_str.decode("ascii") - - # Parse string and check against requested version. - version = [int(piece) for piece in version_str.split(".")] - if len(version) != 2: - raise AssertionError( - f"Expected version to have 2 components (major.minor), got {len(version)}: {version_str}" - ) - if major != version[0]: - return False - if minimum_minor > version[1]: - return False - return True - - -import sys - - -if is_manylinux1_compatible(): - print(f"{sys.executable} is manylinux1 compatible") - sys.exit(0) -else: - print(f"{sys.executable} is NOT manylinux1 compatible") - sys.exit(1) diff --git a/.ci/docker/manywheel/build_scripts/ssl-check.py b/.ci/docker/manywheel/build_scripts/ssl-check.py deleted file mode 100644 index c4df0eacbb7fd..0000000000000 --- a/.ci/docker/manywheel/build_scripts/ssl-check.py +++ /dev/null @@ -1,26 +0,0 @@ -# cf. https://github.com/pypa/manylinux/issues/53 - -import sys -from urllib.request import urlopen - - -GOOD_SSL = "https://google.com" -BAD_SSL = "https://self-signed.badssl.com" - - -print("Testing SSL certificate checking for Python:", sys.version) - -EXC = OSError - -print(f"Connecting to {GOOD_SSL} should work") -urlopen(GOOD_SSL) -print("...it did, yay.") - -print(f"Connecting to {BAD_SSL} should fail") -try: - urlopen(BAD_SSL) - # If we get here then we failed: - print("...it DIDN'T!!!!!11!!1one!") - sys.exit(1) -except EXC: - print("...it did, yay.") diff --git a/.ci/docker/manywheel/build_scripts/build.sh b/.ci/docker/manywheel/s390_scripts/build.sh similarity index 90% rename from .ci/docker/manywheel/build_scripts/build.sh rename to .ci/docker/manywheel/s390_scripts/build.sh index b6a70f0a72787..13141dfd4ae33 100644 --- a/.ci/docker/manywheel/build_scripts/build.sh +++ b/.ci/docker/manywheel/s390_scripts/build.sh @@ -18,13 +18,7 @@ AUTOCONF_HASH=954bd69b391edc12d6a4a51a2dd1476543da5c6bbf05a95b59dc0dd6fd4c2969 # Dependencies for compiling Python that we want to remove from # the final image after compiling Python -PYTHON_COMPILE_DEPS="zlib-devel bzip2-devel ncurses-devel sqlite-devel readline-devel tk-devel gdbm-devel libpcap-devel xz-devel libffi-devel" - -if [ "$(uname -m)" != "s390x" ] ; then - PYTHON_COMPILE_DEPS="${PYTHON_COMPILE_DEPS} db4-devel" -else - PYTHON_COMPILE_DEPS="${PYTHON_COMPILE_DEPS} libdb-devel" -fi +PYTHON_COMPILE_DEPS="zlib-devel bzip2-devel ncurses-devel sqlite-devel readline-devel tk-devel gdbm-devel libpcap-devel xz-devel libffi-devel libdb-devel" # Libraries that are allowed as part of the manylinux1 profile MANYLINUX1_DEPS="glibc-devel libstdc++-devel glib2-devel libX11-devel libXext-devel libXrender-devel mesa-libGL-devel libICE-devel libSM-devel ncurses-devel" @@ -103,13 +97,6 @@ find /opt/_internal \ -o \( -type f -a -name '*.pyc' -o -name '*.pyo' \) \ -print0 | xargs -0 rm -f -for PYTHON in /opt/python/*/bin/python; do - # Smoke test to make sure that our Pythons work, and do indeed detect as - # being manylinux compatible: - $PYTHON $MY_DIR/manylinux1-check.py - # Make sure that SSL cert checking works - $PYTHON $MY_DIR/ssl-check.py -done # Fix libc headers to remain compatible with C99 compilers. find /usr/include/ -type f -exec sed -i 's/\bextern _*inline_*\b/extern __inline __attribute__ ((__gnu_inline__))/g' {} + diff --git a/.ci/docker/manywheel/build_scripts/build_utils.sh b/.ci/docker/manywheel/s390_scripts/build_utils.sh similarity index 100% rename from .ci/docker/manywheel/build_scripts/build_utils.sh rename to .ci/docker/manywheel/s390_scripts/build_utils.sh diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 14b8ff59fcfbe..2f3a8b49a0022 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -292,6 +292,11 @@ lintrunner==0.12.11 #Pinned versions: 0.12.11 #test that import: +spin==0.17 +#Description: developer CLI for common build/lint tasks +#Pinned versions: 0.17 +#test that import: + redis>=4.0.0 #Description: redis database #test that import: anything that tests OSS caching/mocking (inductor/test_codecache.py, inductor/test_max_autotune.py) diff --git a/.ci/pytorch/smoke_test/check_wheel_tags.py b/.ci/pytorch/smoke_test/check_wheel_tags.py new file mode 100644 index 0000000000000..47853183e0270 --- /dev/null +++ b/.ci/pytorch/smoke_test/check_wheel_tags.py @@ -0,0 +1,236 @@ +"""Validate wheel platform tags and macOS dylib minos. + +Supports two modes: +1. Pre-install: reads .whl files from PYTORCH_FINAL_PACKAGE_DIR +2. Post-install: reads metadata from installed torch package (soft warnings) +- (macOS only) dylib minos matches the wheel platform tag +""" + +import os +import platform +import re +import subprocess +import sys +import zipfile +from pathlib import Path + + +EXPECTED_PLATFORM_TAGS: dict[str, str] = { + "linux": r"_x86_64$", + "linux-aarch64": r"_aarch64$", + "windows": r"^win_amd64$", + "win32": r"^win_amd64$", + "macos-arm64": r"^macosx_\d+_\d+_arm64$", + "darwin": r"^macosx_\d+_\d+_(arm64|x86_64)$", +} + + +def _extract_wheel_tags(whl_path: Path) -> list[str]: + """Extract Tag values from the WHEEL metadata file inside a .whl archive.""" + tags = [] + with zipfile.ZipFile(whl_path, "r") as zf: + wheel_files = [n for n in zf.namelist() if n.endswith("/WHEEL")] + if not wheel_files: + return tags + content = zf.read(wheel_files[0]).decode("utf-8") + for line in content.splitlines(): + if line.startswith("Tag:"): + tags.append(line.split(":", 1)[1].strip()) + return tags + + +def _extract_installed_wheel_tags(package: str = "torch") -> list[str]: + """Extract Tag values from an installed package's WHEEL metadata.""" + from importlib.metadata import distribution + + dist = distribution(package) + wheel_text = dist.read_text("WHEEL") + if not wheel_text: + return [] + tags = [] + for line in wheel_text.splitlines(): + if line.startswith("Tag:"): + tags.append(line.split(":", 1)[1].strip()) + return tags + + +def check_wheel_platform_tag() -> None: + """Validate that wheel Tags in WHEEL metadata match the expected platform. + + Mode 1: PYTORCH_FINAL_PACKAGE_DIR set → read .whl file (strict, raises on mismatch) + Mode 2: No wheel dir → read from installed torch package (soft, prints warnings) + """ + wheel_dir = os.getenv("PYTORCH_FINAL_PACKAGE_DIR", "") + + target_os = os.getenv("TARGET_OS", sys.platform) + if target_os == "linux" and platform.machine() == "aarch64": + target_os = "linux-aarch64" + expected_python = f"cp{sys.version_info.major}{sys.version_info.minor}" + abiflags = getattr(sys, "abiflags", "") + expected_abi = f"cp{sys.version_info.major}{sys.version_info.minor}{abiflags}" + + platform_pattern = EXPECTED_PLATFORM_TAGS.get(target_os) + if not platform_pattern: + print( + f"No expected platform pattern for TARGET_OS={target_os}, " + "skipping wheel tag check" + ) + return + + # Mode 1: Read from .whl file + if wheel_dir and os.path.isdir(wheel_dir): + whls = list(Path(wheel_dir).glob("torch-*.whl")) + if not whls: + print(f"No torch wheel found in {wheel_dir}, skipping wheel tag check") + return + if len(whls) > 1: + raise RuntimeError( + f"Expected exactly one torch wheel in {wheel_dir}, " + f"found {len(whls)}: {[w.name for w in whls]}" + ) + whl = whls[0] + print(f"Checking wheel platform tag for: {whl.name}") + tags = _extract_wheel_tags(whl) + source = whl.name + else: + # Mode 2: Read from installed package (soft) + print("PYTORCH_FINAL_PACKAGE_DIR not set, reading from installed torch package") + try: + tags = _extract_installed_wheel_tags("torch") + source = "installed torch" + except Exception as e: + print(f"Could not read installed torch metadata: {e}, skipping") + return + + if not tags: + raise RuntimeError(f"No Tag found in WHEEL metadata of {source}") + + for tag_str in tags: + parts = tag_str.split("-") + if len(parts) != 3: + msg = ( + f"Malformed wheel tag '{tag_str}' in {source}, " + f"expected format: --" + ) + raise RuntimeError(msg) + continue + + python_tag, abi_tag, platform_tag = parts + + print(f"Checking tag: {tag_str} (from {source})") + if python_tag != expected_python: + msg: str = ( + f"Python tag mismatch in {source}: " + f"got '{python_tag}', expected '{expected_python}'" + ) + raise RuntimeError(msg) + + if abi_tag != expected_abi: + msg = ( + f"ABI tag mismatch in {source}: " + f"got '{abi_tag}', expected '{expected_abi}'" + ) + raise RuntimeError(msg) + + if not re.search(platform_pattern, platform_tag): + msg = ( + f"Platform tag mismatch in {source}: " + f"got '{platform_tag}', expected pattern matching " + f"'{platform_pattern}' for TARGET_OS={target_os}" + ) + raise RuntimeError(msg) + + print(f"OK: Wheel tag(s) valid for {source}: {', '.join(tags)}") + + +def check_mac_wheel_minos() -> None: + """Check that dylib minos matches the wheel platform tag on macOS. + + Extracts dylibs from the .whl in PYTORCH_FINAL_PACKAGE_DIR to a temp dir, + then verifies each dylib's minos (from otool -l) matches the platform tag. + """ + if sys.platform != "darwin": + return + + wheel_dir = os.getenv("PYTORCH_FINAL_PACKAGE_DIR", "") + if not wheel_dir or not os.path.isdir(wheel_dir): + print("PYTORCH_FINAL_PACKAGE_DIR not set, skipping wheel minos check") + return + + whls = list(Path(wheel_dir).glob("*.whl")) + if not whls: + print(f"No .whl files in {wheel_dir}, skipping wheel minos check") + return + + import tempfile + + for whl in whls: + print(f"Checking wheel tag minos for: {whl.name}") + + m = re.search(r"macosx_(\d+)_(\d+)_(\w+)\.whl$", whl.name) + if not m: + print(f"No macOS platform tag in {whl.name}, skipping") + continue + + expected_minos = f"{m.group(1)}.{m.group(2)}" + print(f"Expected minos from platform tag: {expected_minos}") + + # Extract dylibs from wheel to temp dir + with tempfile.TemporaryDirectory() as tmpdir: + with zipfile.ZipFile(whl, "r") as zf: + dylib_names = [n for n in zf.namelist() if n.endswith(".dylib")] + if not dylib_names: + print("No .dylib files in wheel, skipping minos check") + continue + for name in dylib_names: + zf.extract(name, tmpdir) + + dylibs = list(Path(tmpdir).rglob("*.dylib")) + mismatches = [] + for dylib in dylibs: + try: + result = subprocess.run( + ["otool", "-l", str(dylib)], + capture_output=True, + text=True, + timeout=30, + ) + except Exception: + continue + + minos = None + lines = result.stdout.splitlines() + for i, line in enumerate(lines): + s = line.strip() + if "LC_BUILD_VERSION" in s: + for j in range(i + 1, min(i + 6, len(lines))): + if lines[j].strip().startswith("minos"): + minos = lines[j].strip().split()[1] + break + break + if "LC_VERSION_MIN_MACOSX" in s: + for j in range(i + 1, min(i + 4, len(lines))): + if lines[j].strip().startswith("version"): + minos = lines[j].strip().split()[1] + break + break + + if minos and minos != expected_minos: + mismatches.append( + f"{dylib.name}: minos={minos}, expected={expected_minos}" + ) + + if mismatches: + raise RuntimeError( + f"minos/platform tag mismatch in {len(mismatches)} dylib(s):\n" + + "\n".join(f" {m}" for m in mismatches) + ) + print( + f"OK: All {len(dylibs)} dylib(s) have minos matching " + f"platform tag ({expected_minos})" + ) + + +if __name__ == "__main__": + check_wheel_platform_tag() + check_mac_wheel_minos() diff --git a/.ci/pytorch/smoke_test/smoke_test.py b/.ci/pytorch/smoke_test/smoke_test.py index eb5e9baaee88d..877218e71e307 100644 --- a/.ci/pytorch/smoke_test/smoke_test.py +++ b/.ci/pytorch/smoke_test/smoke_test.py @@ -10,6 +10,8 @@ from pathlib import Path from tempfile import NamedTemporaryFile +from check_wheel_tags import check_mac_wheel_minos, check_wheel_platform_tag + import torch import torch._dynamo import torch.nn as nn @@ -637,6 +639,9 @@ def main() -> None: smoke_test_nvshmem() + check_wheel_platform_tag() + check_mac_wheel_minos() + if __name__ == "__main__": main() diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index c7310472f6d25..a82dc24f8fd33 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -14,9 +14,10 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" # shellcheck source=./common-build.sh source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" -# Do not change workspace permissions for ROCm and s390x CI jobs -# as it can leave workspace with bad permissions for cancelled jobs -if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && -d /var/lib/jenkins/workspace ]]; then +# Only change workspace permissions if passwordless sudo is available +# (e.g. ROCm and s390x CI jobs lack it, and changing permissions +# can leave the workspace in a bad state for cancelled jobs) +if sudo -n true 2>/dev/null && [[ -d /var/lib/jenkins/workspace ]]; then # Workaround for dind-rootless userid mapping (https://github.com/pytorch/ci-infra/issues/96) WORKSPACE_ORIGINAL_OWNER_ID=$(stat -c '%u' "/var/lib/jenkins/workspace") cleanup_workspace() { @@ -360,6 +361,7 @@ test_python_smoke_b200() { inductor/test_torchinductor \ inductor/test_nv_universal_gemm \ inductor/test_fused_attention \ + test_varlen_attention \ $PYTHON_TEST_EXTRA_OPTION \ --upload-artifacts-while-running assert_git_not_dirty diff --git a/.github/actions/setup-xpu/action.yml b/.github/actions/setup-xpu/action.yml index 740492475d6e2..dfae0a358259e 100644 --- a/.github/actions/setup-xpu/action.yml +++ b/.github/actions/setup-xpu/action.yml @@ -65,3 +65,6 @@ runs: # Add render group for container creation. render_gid=`cat /etc/group | grep render | cut -d: -f3` echo "GPU_FLAG=--device=/dev/mem --device=/dev/dri --group-add video --group-add $render_gid" >> "${GITHUB_ENV}" + + - name: Login to ECR + uses: ./.github/actions/ecr-login diff --git a/.github/arc.yaml b/.github/arc.yaml new file mode 100644 index 0000000000000..0a74e2826ab0b --- /dev/null +++ b/.github/arc.yaml @@ -0,0 +1,109 @@ +# ARC (Actions Runner Controller) Runner Label Mapping +# +# Maps current GitHub Actions runner labels to new ARC runner labels. +# Reference: https://github.com/pytorch/ci-infra/issues/396 +# +# New label format: +# {os}-[b]{arch}{vendor}{features}-{vcpu}-{memory}[-{gpu_type}[-{gpu_count}]] +# +# Fields: +# os - l=Linux, w=Windows, m=MacOS +# b - (optional) bare-metal instance +# arch - x86=x86_64, arm64=AArch64 +# vendor - i=Intel, a=AMD, g2/g3/g4=Graviton gen +# features - (x86 only) avx2, avx512, amx +# vcpu - vCPU count +# memory - RAM in GiB +# gpu_type - (optional) t4, a10g, l4 +# gpu_count- (optional, omitted when 1) +# +# Entries marked "# upgraded" had no exact ARC equivalent and were mapped to +# the next larger available runner. + +runner_mapping: + + # ---- x86 CPU — Intel AVX-512 (c5, c7i families) ---- + + linux.large: l-x86iavx512-2-4 # c5.large + linux.c7i.large: l-x86iavx512-2-4 # c7i.large + linux.2xlarge: l-x86iavx512-8-16 # c5.2xlarge + linux.c7i.2xlarge: l-x86iavx512-8-16 # c7i.2xlarge + linux.4xlarge: l-x86iavx512-16-32 # c5.4xlarge + linux.4xlarge.for.testing.donotuse: l-x86iavx512-16-32 # c5.4xlarge + linux.c7i.4xlarge: l-x86iavx512-16-32 # c7i.4xlarge + linux.c7i.8xlarge: l-x86iavx512-48-96 # c7i.8xlarge — upgraded (no 32-64 equivalent) + linux.9xlarge.ephemeral: l-x86iavx512-36-72 # c5.9xlarge + linux.12xlarge: l-x86iavx512-48-96 # c5.12xlarge + linux.12xlarge.ephemeral: l-x86iavx512-48-96 # c5.12xlarge + linux.c7i.12xlarge: l-x86iavx512-48-96 # c7i.12xlarge + linux.16xlarge.spr: l-x86iavx512-94-192 # c7i.16xlarge — upgraded (no 64-128 equivalent) + linux.24xlarge: l-x86iavx512-94-192 # c5.24xlarge + linux.24xlarge.ephemeral: l-x86iavx512-94-192 # c5.24xlarge + linux.c7i.24xlarge: l-x86iavx512-94-192 # c7i.24xlarge + linux.24xl.spr-metal: l-bx86iamx-94-192 # c7i.metal-24xl + + # ---- x86 CPU — Intel AMX (m7i-flex family) ---- + + linux.2xlarge.amx: l-x86iamx-8-32 # m7i-flex.2xlarge + linux.4xlarge.amx: l-x86iamx-32-128 # m7i-flex.4xlarge — upgraded (no 16-64 equivalent) + linux.8xlarge.amx: l-x86iamx-32-128 # m7i-flex.8xlarge + + # ---- x86 CPU — Intel AVX2 (m4 family) ---- + + linux.2xlarge.avx2: l-x86iavx2-8-32 # m4.2xlarge + linux.4xlarge.avx2: l-x86iavx2-40-160 # m4.4xlarge — upgraded (no 16-64 equivalent) + linux.10xlarge.avx2: l-x86iavx2-40-160 # m4.10xlarge + + # ---- x86 CPU — Memory-optimized (r5, r7i families) ---- + + linux.r7i.large: l-x86iavx512-8-64 # r7i.large — upgraded (no 2-16 equivalent) + linux.r7i.xlarge: l-x86iavx512-8-64 # r7i.xlarge — upgraded (no 4-32 equivalent) + linux.r7i.2xlarge: l-x86iavx512-8-64 # r7i.2xlarge + linux.r7i.4xlarge: l-x86iavx512-16-128 # r7i.4xlarge + linux.r7i.8xlarge: l-x86iavx512-32-256 # r7i.8xlarge + linux.r7i.12xlarge: l-x86iavx512-48-384 # r7i.12xlarge + linux.2xlarge.memory: l-x86iavx512-8-64 # r5.2xlarge + linux.4xlarge.memory: l-x86iavx512-16-128 # r5.4xlarge + linux.8xlarge.memory: l-x86iavx512-32-256 # r5.8xlarge + linux.12xlarge.memory: l-x86iavx512-48-384 # r5.12xlarge + linux.12xlarge.memory.ephemeral: l-x86iavx512-48-384 # r5.12xlarge + linux.16xlarge.memory: l-x86iavx512-94-768 # r5.16xlarge — upgraded (no 64-512 equivalent) + linux.24xlarge.memory: l-x86iavx512-94-768 # r5.24xlarge + + # ---- x86 CPU — AMD (m6a, m7a families) ---- + + linux.8xlarge.amd: l-x86aavx512-94-384 # m7a.8xlarge — upgraded (no 32-128 equivalent) + linux.12xlarge.amd: l-x86aavx512-94-384 # m6a.12xlarge — upgraded (no 48-192 equivalent) + linux.24xlarge.amd: l-x86aavx512-94-384 # m7a.24xlarge + + # ---- x86 GPU — T4 (g4dn family) ---- + + linux.g4dn.4xlarge.nvidia.gpu: l-x86iavx512-16-64-t4 # g4dn.4xlarge + linux.g4dn.12xlarge.nvidia.gpu: l-x86iavx512-48-192-t4-4 # g4dn.12xlarge + linux.g4dn.metal.nvidia.gpu: l-bx86iavx512-94-384-t4-8 # g4dn.metal + + # ---- x86 GPU — A10G (g5 family) ---- + + linux.g5.4xlarge.nvidia.gpu: l-x86aavx2-16-64-a10g # g5.4xlarge + linux.g5.12xlarge.nvidia.gpu: l-x86aavx2-48-192-a10g-4 # g5.12xlarge + linux.g5.48xlarge.nvidia.gpu: l-x86aavx2-192-768-a10g-8 # g5.48xlarge + + # ---- x86 GPU — L4 (g6 family) ---- + + linux.g6.4xlarge.experimental.nvidia.gpu: l-x86aavx2-16-64-l4 # g6.4xlarge + linux.g6.12xlarge.nvidia.gpu: l-x86aavx2-48-192-l4-4 # g6.12xlarge + + # ---- x86 GPU — V100 (p3 family) ---- + + linux.p3.8xlarge.nvidia.gpu: l-x86aavx2-48-192-a10g-4 # p3.8xlarge — upgraded (no V100 equivalent; 4x A10G closest match) + + # ---- ARM64 — Graviton ---- + + linux.arm64.2xlarge: l-arm64g2-6-32 # t4g.2xlarge + linux.arm64.2xlarge.ephemeral: l-arm64g2-6-32 # t4g.2xlarge + linux.arm64.m7g.4xlarge: l-arm64g3-16-64 # m7g.4xlarge + linux.arm64.m7g.4xlarge.ephemeral: l-arm64g3-16-64 # m7g.4xlarge + linux.arm64.m8g.4xlarge: l-arm64g4-16-64 # m8g.4xlarge + linux.arm64.m8g.4xlarge.ephemeral: l-arm64g4-16-64 # m8g.4xlarge + linux.arm64.r7g.12xlarge.memory: l-arm64g3-48-384 # r7g.12xlarge + linux.arm64.m7g.metal: l-barm64g3-62-256 # m7g.metal diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 8acb4efa1d8e9..cd9a718f496c7 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -24,7 +24,7 @@ CUDA_ARCHES = ["12.6", "12.8", "12.9", "13.0", "13.2"] -CUDA_STABLE = "12.8" +CUDA_STABLE = "13.0" CUDA_ARCHES_FULL_VERSION = { "12.6": "12.6.3", "12.8": "12.8.1", @@ -288,12 +288,6 @@ def arch_type(arch_version: str) -> str: RELEASE = "release" DEBUG = "debug" -LIBTORCH_CONTAINER_IMAGES: dict[str, str] = { - **{gpu_arch: f"libtorch-cxx11-builder:cuda{gpu_arch}" for gpu_arch in CUDA_ARCHES}, - **{gpu_arch: f"libtorch-cxx11-builder:rocm{gpu_arch}" for gpu_arch in ROCM_ARCHES}, - "cpu": "libtorch-cxx11-builder:cpu", -} - FULL_PYTHON_VERSIONS = ["3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t"] @@ -321,10 +315,7 @@ def generate_libtorch_matrix( ) -> list[dict[str, str]]: if arches is None: arches = ["cpu"] - if os == "linux": - arches += CUDA_ARCHES - arches += ROCM_ARCHES - elif os == "windows": + if os == "windows": # TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up # in 2.10 windows_cuda_arches = CUDA_ARCHES.copy() @@ -344,9 +335,6 @@ def generate_libtorch_matrix( for libtorch_variant in libtorch_variants: gpu_arch_type = arch_type(arch_version) gpu_arch_version = "" if arch_version == "cpu" else arch_version - # ROCm builds without-deps failed even in ROCm runners; skip for now - if gpu_arch_type == "rocm" and ("without-deps" in libtorch_variant): - continue ret.append( { "gpu_arch_type": gpu_arch_type, @@ -356,16 +344,8 @@ def generate_libtorch_matrix( ), "libtorch_config": release_type, "libtorch_variant": libtorch_variant, - "container_image": ( - LIBTORCH_CONTAINER_IMAGES[arch_version].split(":")[0] - if os not in ("windows", "windows-arm64") - else "" - ), - "container_image_tag_prefix": ( - LIBTORCH_CONTAINER_IMAGES[arch_version].split(":")[1] - if os not in ("windows", "windows-arm64") - else "" - ), + "container_image": "", + "container_image_tag_prefix": "", "package_type": "libtorch", "build_name": f"libtorch-{gpu_arch_type}{gpu_arch_version}-{libtorch_variant}-{release_type}".replace( ".", "_" diff --git a/.github/scripts/runner_determinator.py b/.github/scripts/runner_determinator.py index baf560234549b..da57e5f668d3b 100644 --- a/.github/scripts/runner_determinator.py +++ b/.github/scripts/runner_determinator.py @@ -79,13 +79,18 @@ GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "") GH_OUTPUT_KEY_AMI = "runner-ami" GH_OUTPUT_KEY_LABEL_TYPE = "label-type" +GH_OUTPUT_KEY_USE_ARC = "use-arc" OPT_OUT_LABEL = "no-runner-experiments" SETTING_EXPERIMENTS = "experiments" LF_FLEET_EXPERIMENT = "lf" +ARC_FLEET_EXPERIMENT = "arc" CANARY_FLEET_SUFFIX = ".c" +ARC_LABEL_PREFIX = "mt-" +ARC_CANARY_LABEL_PREFIX = "c-" + class Experiment(NamedTuple): rollout_perc: float = ( @@ -101,6 +106,11 @@ class Experiment(NamedTuple): # Add more fields as needed +class RunnerPrefixResult(NamedTuple): + prefix: str + use_arc: bool = False + + class Settings(NamedTuple): """ Settings for the experiments that can be opted into. @@ -439,12 +449,13 @@ def get_runner_prefix( eligible_experiments: frozenset[str] = frozenset(), opt_out_experiments: frozenset[str] = frozenset(), is_canary: bool = False, -) -> str: +) -> RunnerPrefixResult: settings = parse_settings(rollout_state) user_optins = parse_users(rollout_state) fleet_prefix = "" prefixes = [] + use_arc = False for experiment_name, experiment_settings in settings.experiments.items(): if not experiment_settings.all_branches and is_exception_branch(branch): log.info( @@ -510,7 +521,12 @@ def get_runner_prefix( if enabled: label = experiment_name - if experiment_name == LF_FLEET_EXPERIMENT: + if experiment_name == ARC_FLEET_EXPERIMENT: + use_arc = True + log.info( + f"ARC experiment enabled. Using ARC runner prefix ({'canary' if is_canary else 'production'})." + ) + elif experiment_name == LF_FLEET_EXPERIMENT: # We give some special treatment to the "lf" experiment since determines the fleet we use # - If it's enabled, then we always list it's prefix first # - If we're in the canary branch, then we append ".c" to the lf prefix @@ -520,6 +536,15 @@ def get_runner_prefix( else: prefixes.append(label) + # ARC experiment takes precedence: return a fixed label prefix + if use_arc: + arc_prefix = ( + ARC_CANARY_LABEL_PREFIX + ARC_LABEL_PREFIX + if is_canary + else ARC_LABEL_PREFIX + ) + return RunnerPrefixResult(prefix=arc_prefix, use_arc=True) + if len(prefixes) > 1: log.error( f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}" @@ -530,7 +555,8 @@ def get_runner_prefix( if fleet_prefix: prefixes.insert(0, fleet_prefix) - return ".".join(prefixes) + "." if prefixes else "" + prefix = ".".join(prefixes) + "." if prefixes else "" + return RunnerPrefixResult(prefix=prefix) def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: @@ -619,7 +645,7 @@ def main() -> None: is_canary = args.github_repo == "pytorch/pytorch-canary" - runner_label_prefix = get_runner_prefix( + result = get_runner_prefix( rollout_state, (args.github_issue_owner, username), args.github_branch, @@ -627,6 +653,8 @@ def main() -> None: args.opt_out_experiments, is_canary, ) + runner_label_prefix = result.prefix + set_github_output(GH_OUTPUT_KEY_USE_ARC, str(result.use_arc).lower()) except Exception as e: log.error( diff --git a/.github/scripts/test_runner_determinator.py b/.github/scripts/test_runner_determinator.py index e8f9f1b8b4aa6..47d54af377f45 100644 --- a/.github/scripts/test_runner_determinator.py +++ b/.github/scripts/test_runner_determinator.py @@ -183,8 +183,8 @@ def test_opted_in_user(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) - self.assertEqual("lf.", prefix, "Runner prefix not correct for User1") + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("lf.", result.prefix, "Runner prefix not correct for User1") def test_explicitly_opted_out_user(self) -> None: settings_text = """ @@ -200,8 +200,8 @@ def test_explicitly_opted_out_user(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) - self.assertEqual("", prefix, "Runner prefix not correct for User1") + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for User1") def test_explicitly_opted_in_and_out_user_should_opt_out(self) -> None: settings_text = """ @@ -217,8 +217,8 @@ def test_explicitly_opted_in_and_out_user_should_opt_out(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) - self.assertEqual("", prefix, "Runner prefix not correct for User1") + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for User1") def test_opted_in_user_two_experiments(self) -> None: settings_text = """ @@ -234,8 +234,10 @@ def test_opted_in_user_two_experiments(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for User2") + result = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) + self.assertEqual( + "lf.otherExp.", result.prefix, "Runner prefix not correct for User2" + ) def test_opted_in_user_two_experiments_default(self) -> None: settings_text = """ @@ -252,8 +254,8 @@ def test_opted_in_user_two_experiments_default(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) - self.assertEqual("lf.", prefix, "Runner prefix not correct for User2") + result = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) + self.assertEqual("lf.", result.prefix, "Runner prefix not correct for User2") def test_opted_in_user_two_experiments_default_exp(self) -> None: settings_text = """ @@ -270,10 +272,12 @@ def test_opted_in_user_two_experiments_default_exp(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix( + result = rd.get_runner_prefix( settings_text, ["User2"], USER_BRANCH, frozenset(["lf", "otherExp"]) ) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for User2") + self.assertEqual( + "lf.otherExp.", result.prefix, "Runner prefix not correct for User2" + ) def test_opted_in_user_two_experiments_default_exp_2(self) -> None: settings_text = """ @@ -290,10 +294,12 @@ def test_opted_in_user_two_experiments_default_exp_2(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix( + result = rd.get_runner_prefix( settings_text, ["User2"], USER_BRANCH, frozenset(["otherExp"]) ) - self.assertEqual("otherExp.", prefix, "Runner prefix not correct for User2") + self.assertEqual( + "otherExp.", result.prefix, "Runner prefix not correct for User2" + ) @patch("random.uniform", return_value=50) def test_opted_out_user(self, mock_uniform: Mock) -> None: @@ -310,8 +316,8 @@ def test_opted_out_user(self, mock_uniform: Mock) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) - self.assertEqual("", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for user") @patch("random.uniform", return_value=10) def test_opted_out_user_was_pulled_in_by_rollout(self, mock_uniform: Mock) -> None: @@ -330,8 +336,10 @@ def test_opted_out_user_was_pulled_in_by_rollout(self, mock_uniform: Mock) -> No """ # User3 is opted out, but is pulled into both experiments by the 10% rollout - prefix = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual( + "lf.otherExp.", result.prefix, "Runner prefix not correct for user" + ) @patch("random.uniform", return_value=10) def test_opted_out_user_was_pulled_in_by_rollout_excl_nondefault( @@ -353,8 +361,8 @@ def test_opted_out_user_was_pulled_in_by_rollout_excl_nondefault( """ # User3 is opted out, but is pulled into default experiments by the 10% rollout - prefix = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) - self.assertEqual("lf.", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("lf.", result.prefix, "Runner prefix not correct for user") @patch("random.uniform", return_value=10) def test_opted_out_user_was_pulled_in_by_rollout_filter_exp( @@ -376,10 +384,12 @@ def test_opted_out_user_was_pulled_in_by_rollout_filter_exp( """ # User3 is opted out, but is pulled into default experiments by the 10% rollout - prefix = rd.get_runner_prefix( + result = rd.get_runner_prefix( settings_text, ["User3"], USER_BRANCH, frozenset(["otherExp"]) ) - self.assertEqual("otherExp.", prefix, "Runner prefix not correct for user") + self.assertEqual( + "otherExp.", result.prefix, "Runner prefix not correct for user" + ) @patch("random.uniform", return_value=25) def test_opted_out_user_was_pulled_out_by_rollout_filter_exp( @@ -401,8 +411,8 @@ def test_opted_out_user_was_pulled_out_by_rollout_filter_exp( """ # User3 is opted out, but is pulled into default experiments by the 10% rollout - prefix = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) - self.assertEqual("", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for user") def test_lf_prefix_always_comes_first(self) -> None: settings_text = """ @@ -419,8 +429,10 @@ def test_lf_prefix_always_comes_first(self) -> None: """ - prefix = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) + self.assertEqual( + "lf.otherExp.", result.prefix, "Runner prefix not correct for user" + ) def test_ignores_commented_users(self) -> None: settings_text = """ @@ -437,8 +449,8 @@ def test_ignores_commented_users(self) -> None: """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) - self.assertEqual("", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for user") def test_ignores_extra_experiments(self) -> None: settings_text = """ @@ -456,8 +468,10 @@ def test_ignores_extra_experiments(self) -> None: """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual( + "lf.otherExp.", result.prefix, "Runner prefix not correct for user" + ) def test_disables_experiment_on_exception_branches_when_not_explicitly_opted_in( self, @@ -473,8 +487,8 @@ def test_disables_experiment_on_exception_branches_when_not_explicitly_opted_in( """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], EXCEPTION_BRANCH) - self.assertEqual("", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User1"], EXCEPTION_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for user") def test_allows_experiment_on_exception_branches_when_explicitly_opted_in( self, @@ -491,8 +505,136 @@ def test_allows_experiment_on_exception_branches_when_explicitly_opted_in( """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], EXCEPTION_BRANCH) - self.assertEqual("lf.", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User1"], EXCEPTION_BRANCH) + self.assertEqual("lf.", result.prefix, "Runner prefix not correct for user") + + +class TestRunnerDeterminatorArcExperiment(TestCase): + ARC_SETTINGS = """ + experiments: + arc: + rollout_perc: 0 + --- + + Users: + @User1,arc + @User2,lf + + """ + + def test_arc_opted_in_user_returns_mt_prefix(self) -> None: + result = rd.get_runner_prefix(self.ARC_SETTINGS, ["User1"], USER_BRANCH) + self.assertEqual("mt-", result.prefix) + self.assertTrue(result.use_arc) + + def test_arc_opted_in_user_canary_returns_c_mt_prefix(self) -> None: + result = rd.get_runner_prefix( + self.ARC_SETTINGS, ["User1"], USER_BRANCH, is_canary=True + ) + self.assertEqual("c-mt-", result.prefix) + self.assertTrue(result.use_arc) + + def test_arc_not_enabled_returns_use_arc_false(self) -> None: + result = rd.get_runner_prefix(self.ARC_SETTINGS, ["User2"], USER_BRANCH) + self.assertFalse(result.use_arc) + + def test_arc_not_enabled_no_experiments_returns_use_arc_false(self) -> None: + settings_text = """ + experiments: + arc: + rollout_perc: 0 + --- + + Users: + @User1,arc + + """ + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("", result.prefix) + self.assertFalse(result.use_arc) + + @patch("random.uniform", return_value=10) + def test_arc_rollout_percentage(self, mock_uniform: Mock) -> None: + settings_text = """ + experiments: + arc: + rollout_perc: 25 + --- + + Users: + + """ + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("mt-", result.prefix) + self.assertTrue(result.use_arc) + + @patch("random.uniform", return_value=50) + def test_arc_rollout_percentage_not_selected(self, mock_uniform: Mock) -> None: + settings_text = """ + experiments: + arc: + rollout_perc: 25 + --- + + Users: + + """ + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("", result.prefix) + self.assertFalse(result.use_arc) + + def test_arc_opted_out_user(self) -> None: + settings_text = """ + experiments: + arc: + rollout_perc: 100 + --- + + Users: + @User1,-arc + + """ + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("", result.prefix) + self.assertFalse(result.use_arc) + + def test_arc_exception_branch_not_enabled(self) -> None: + result = rd.get_runner_prefix(self.ARC_SETTINGS, ["User1"], EXCEPTION_BRANCH) + self.assertEqual("", result.prefix) + self.assertFalse(result.use_arc) + + def test_arc_exception_branch_all_branches(self) -> None: + settings_text = """ + experiments: + arc: + rollout_perc: 0 + all_branches: true + --- + + Users: + @User1,arc + + """ + result = rd.get_runner_prefix(settings_text, ["User1"], EXCEPTION_BRANCH) + self.assertEqual("mt-", result.prefix) + self.assertTrue(result.use_arc) + + def test_arc_takes_precedence_over_lf(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + arc: + rollout_perc: 0 + --- + + Users: + @User1,lf,arc + + """ + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("mt-", result.prefix) + self.assertTrue(result.use_arc) if __name__ == "__main__": diff --git a/.github/workflows/_link_check_osdc.yml b/.github/workflows/_link_check_osdc.yml new file mode 100644 index 0000000000000..d170d6359cb82 --- /dev/null +++ b/.github/workflows/_link_check_osdc.yml @@ -0,0 +1,52 @@ +on: + workflow_call: + inputs: + runner: + type: string + required: true + ref: + type: string + required: true + +jobs: + lint-urls: + if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-url-lint') }} + uses: ./.github/workflows/_lint.yml + with: + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-67eb930 + script: | + ./scripts/lint_urls.sh $( + if [ "${{ github.event_name }}" = "pull_request" ]; then + echo "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" + else + echo "${{ github.event.before }}" "${{ github.sha }}" + fi + ) || { + echo + echo "URL lint failed." + echo "If this is a transient outage, you can bypass it by adding the \`skip-url-lint\` label to your PR." + echo "Or add \`@lint-ignore\` somewhere on the same line as the URL you want to skip checking." + exit 1 + } + + lint-xrefs: + if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-xref-lint') }} + uses: ./.github/workflows/_lint.yml + with: + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-67eb930 + script: | + ./scripts/lint_xrefs.sh $( + if [ "${{ github.event_name }}" = "pull_request" ]; then + echo "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" + else + echo "${{ github.event.before }}" "${{ github.sha }}" + fi + ) || { + echo + echo "Xref lint failed." + echo "If this is a transient outage, you can bypass it by adding the \`skip-xref-lint\` label to your PR." + echo "Or add \`@lint-ignore\` somewhere on the same line as the reference you want to skip checking." + exit 1 + } diff --git a/.github/workflows/_lint.yml b/.github/workflows/_lint.yml new file mode 100644 index 0000000000000..8e707c748f582 --- /dev/null +++ b/.github/workflows/_lint.yml @@ -0,0 +1,82 @@ +name: Run linters + +on: + workflow_call: + inputs: + runner: + required: true + type: string + description: The runner to use + docker-image: + required: true + type: string + description: The Docker image to use + script: + required: true + type: string + description: The linter script to run + +jobs: + lint: + runs-on: ${{ inputs.runner }} + container: + image: ${{ inputs.docker-image }} + timeout-minutes: 120 + steps: + - name: Fix Git ownership + shell: bash + run: | + git config --global --add safe.directory "$GITHUB_WORKSPACE" + + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + with: + fetch-depth: 0 + submodules: true + no-sudo: true + checkout-mode: treeless + + - name: Setup uv + uses: pytorch/test-infra/.github/actions/setup-uv@main + with: + python-version: "3.12" + activate-environment: true + + - name: Install pip requirements + shell: bash + run: | + set -eux + uv pip install -r .ci/docker/requirements-ci.txt + + - name: Install system requirements + shell: bash + run: | + set -eux + # Update repository + dnf install -y doxygen graphviz nodejs npm + + - name: Install Node.js packages + shell: bash + run: | + set -eux + npm install -g markdown-toc + + - name: Prepare lintrunner + shell: bash + run: | + set -eux + lintrunner init + + - name: Run linter + shell: bash + env: + SCRIPT: ${{ inputs.script }} + run: | + { + echo "#!/usr/bin/env bash"; + echo "set -eou pipefail"; + echo "${SCRIPT}"; + } > "${RUNNER_TEMP}/linter_script" + + # Execute the linter script + bash "${RUNNER_TEMP}/linter_script" diff --git a/.github/workflows/_runner-determinator.yml b/.github/workflows/_runner-determinator.yml index 0d674f044ec42..ec3d05c316c83 100644 --- a/.github/workflows/_runner-determinator.yml +++ b/.github/workflows/_runner-determinator.yml @@ -41,6 +41,9 @@ on: label-type: description: Type of runners to use value: ${{ jobs.runner-determinator.outputs.label-type }} + use-arc: + description: Whether to use ARC runners + value: ${{ jobs.runner-determinator.outputs.use-arc }} jobs: runner-determinator: @@ -49,6 +52,7 @@ jobs: runs-on: ubuntu-latest outputs: label-type: ${{ steps.set-condition.outputs.label-type }} + use-arc: ${{ steps.set-condition.outputs.use-arc }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} ISSUE_NUMBER: ${{ inputs.issue_number }} @@ -58,12 +62,6 @@ jobs: OPT_OUT_EXPERIMENTS: ${{ inputs.opt_out_experiments }} PR_NUMBER: ${{ github.event.pull_request.number }} steps: - # - name: Checkout PyTorch - # uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - # with: - # fetch-depth: 1 - # submodules: true - # TODO: Remove the hardcoded step below # Hardcoding below is temporary for testing ALI runners # This file below should match the script found in .github/scripts/runner_determinator.py @@ -152,13 +150,18 @@ jobs: GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "") GH_OUTPUT_KEY_AMI = "runner-ami" GH_OUTPUT_KEY_LABEL_TYPE = "label-type" + GH_OUTPUT_KEY_USE_ARC = "use-arc" OPT_OUT_LABEL = "no-runner-experiments" SETTING_EXPERIMENTS = "experiments" LF_FLEET_EXPERIMENT = "lf" + ARC_FLEET_EXPERIMENT = "arc" CANARY_FLEET_SUFFIX = ".c" + ARC_LABEL_PREFIX = "mt-" + ARC_CANARY_LABEL_PREFIX = "c-" + class Experiment(NamedTuple): rollout_perc: float = ( @@ -174,6 +177,11 @@ jobs: # Add more fields as needed + class RunnerPrefixResult(NamedTuple): + prefix: str + use_arc: bool = False + + class Settings(NamedTuple): """ Settings for the experiments that can be opted into. @@ -335,7 +343,12 @@ jobs: """ Branches that get opted out of experiments by default, until they're explicitly enabled. """ - return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} + return branch.split("/", maxsplit=1)[0] in { + "main", + "nightly", + "release", + "landchecks", + } def load_yaml(yaml_text: str) -> Any: @@ -507,12 +520,13 @@ jobs: eligible_experiments: frozenset[str] = frozenset(), opt_out_experiments: frozenset[str] = frozenset(), is_canary: bool = False, - ) -> str: + ) -> RunnerPrefixResult: settings = parse_settings(rollout_state) user_optins = parse_users(rollout_state) fleet_prefix = "" prefixes = [] + use_arc = False for experiment_name, experiment_settings in settings.experiments.items(): if not experiment_settings.all_branches and is_exception_branch(branch): log.info( @@ -578,7 +592,12 @@ jobs: if enabled: label = experiment_name - if experiment_name == LF_FLEET_EXPERIMENT: + if experiment_name == ARC_FLEET_EXPERIMENT: + use_arc = True + log.info( + f"ARC experiment enabled. Using ARC runner prefix ({'canary' if is_canary else 'production'})." + ) + elif experiment_name == LF_FLEET_EXPERIMENT: # We give some special treatment to the "lf" experiment since determines the fleet we use # - If it's enabled, then we always list it's prefix first # - If we're in the canary branch, then we append ".c" to the lf prefix @@ -588,6 +607,15 @@ jobs: else: prefixes.append(label) + # ARC experiment takes precedence: return a fixed label prefix + if use_arc: + arc_prefix = ( + ARC_CANARY_LABEL_PREFIX + ARC_LABEL_PREFIX + if is_canary + else ARC_LABEL_PREFIX + ) + return RunnerPrefixResult(prefix=arc_prefix, use_arc=True) + if len(prefixes) > 1: log.error( f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}" @@ -598,7 +626,8 @@ jobs: if fleet_prefix: prefixes.insert(0, fleet_prefix) - return ".".join(prefixes) + "." if prefixes else "" + prefix = ".".join(prefixes) + "." if prefixes else "" + return RunnerPrefixResult(prefix=prefix) def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: @@ -687,7 +716,7 @@ jobs: is_canary = args.github_repo == "pytorch/pytorch-canary" - runner_label_prefix = get_runner_prefix( + result = get_runner_prefix( rollout_state, (args.github_issue_owner, username), args.github_branch, @@ -695,6 +724,8 @@ jobs: args.opt_out_experiments, is_canary, ) + runner_label_prefix = result.prefix + set_github_output(GH_OUTPUT_KEY_USE_ARC, str(result.use_arc).lower()) except Exception as e: log.error( diff --git a/.github/workflows/_xpu-test.yml b/.github/workflows/_xpu-test.yml index 5724403e6de44..bd6e0a6fd2563 100644 --- a/.github/workflows/_xpu-test.yml +++ b/.github/workflows/_xpu-test.yml @@ -91,9 +91,6 @@ jobs: - name: Setup XPU uses: ./.github/actions/setup-xpu - - name: Login to ECR - uses: ./.github/actions/ecr-login - - name: Calculate docker image id: calculate-docker-image uses: pytorch/test-infra/.github/actions/calculate-docker-image@main diff --git a/.github/workflows/build-libtorch-images.yml b/.github/workflows/build-libtorch-images.yml deleted file mode 100644 index 47bf15e1db3ab..0000000000000 --- a/.github/workflows/build-libtorch-images.yml +++ /dev/null @@ -1,68 +0,0 @@ -name: Build libtorch docker images - -on: - push: - branches: - - main - - release/* - tags: - # NOTE: Binary build pipelines should only get triggered on release candidate or nightly builds - # Release candidate tags look like: v1.11.0-rc1 - - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ - paths: - - .ci/docker/** - - .github/workflows/build-libtorch-images.yml - - .github/actions/binary-docker-build/** - pull_request: - paths: - - .ci/docker/** - - .github/workflows/build-libtorch-images.yml - - .github/actions/binary-docker-build/** - -env: - DOCKER_REGISTRY: "docker.io" - DOCKER_BUILDKIT: 1 - WITH_PUSH: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release') || startsWith(github.ref, 'refs/tags/v')) }} - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true - -jobs: - get-label-type: - if: github.repository_owner == 'pytorch' - name: get-label-type - uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main - with: - triggering_actor: ${{ github.triggering_actor }} - issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} - curr_branch: ${{ github.head_ref || github.ref_name }} - curr_ref_type: ${{ github.ref_type }} - - build: - environment: ${{ (github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release') || startsWith(github.ref, 'refs/tags/v')) && 'docker-build') || '' }} - needs: get-label-type - runs-on: ${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral - name: libtorch-cxx11-builder:${{ matrix.tag }} - strategy: - fail-fast: false - matrix: - include: [ - { tag: "cuda13.0" }, - { tag: "cuda12.9" }, - { tag: "cuda12.8" }, - { tag: "cuda12.6" }, - { tag: "rocm7.0" }, - { tag: "rocm7.1" }, - { tag: "rocm7.2" }, - { tag: "cpu" }, - ] - steps: - - name: Build docker image - uses: pytorch/pytorch/.github/actions/binary-docker-build@main - with: - docker-image-name: libtorch-cxx11-builder - custom-tag-prefix: ${{ matrix.tag }} - docker-build-dir: libtorch - DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }} - DOCKER_ID: ${{ secrets.DOCKER_ID }} diff --git a/.github/workflows/lint-osdc.yml b/.github/workflows/lint-osdc.yml new file mode 100644 index 0000000000000..1363bddf48dfe --- /dev/null +++ b/.github/workflows/lint-osdc.yml @@ -0,0 +1,373 @@ +name: Lint OSDC (unstable) +# Unstable version of the lint workflow running on k8s ARC runners. +# Runs in parallel with the existing lint workflow during rollout. + +on: + pull_request: + branches-ignore: + - nightly + push: + branches: + - main + - release/* + - landchecks/* + tags: + - ciflow/pull/* + - ciflow/trunk/* + workflow_dispatch: + +permissions: read-all +# The names of steps that actually test the code should be suffixed with `(nonretryable)`. +# When any other step fails, it's job will be retried once by retryBot. +jobs: + get-label-type: + if: github.repository_owner == 'pytorch' + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + + get-changed-files: + if: github.repository_owner == 'pytorch' + name: Get changed files + uses: ./.github/workflows/_get-changed-files.yml + with: + all_files: ${{ contains(github.event.pull_request.labels.*.name, 'lint-all-files') || contains(github.event.pull_request.labels.*.name, 'Reverted') || github.event_name == 'push' }} + + lintrunner-clang: + uses: ./.github/workflows/_lint.yml + # Needed to prevent deduping on HUD + name: lintrunner-clang-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }} + needs: [get-label-type, get-changed-files] + # Only run if there are changed files relevant to clangtidy / clangformat + if: | + github.repository_owner == 'pytorch' && ( + needs.get-changed-files.outputs.changed-files == '*' || + contains(needs.get-changed-files.outputs.changed-files, '.h') || + contains(needs.get-changed-files.outputs.changed-files, '.cpp') || + contains(needs.get-changed-files.outputs.changed-files, '.cc') || + contains(needs.get-changed-files.outputs.changed-files, '.cxx') || + contains(needs.get-changed-files.outputs.changed-files, '.hpp') || + contains(needs.get-changed-files.outputs.changed-files, '.hxx') || + contains(needs.get-changed-files.outputs.changed-files, '.cu') || + contains(needs.get-changed-files.outputs.changed-files, '.cuh') || + contains(needs.get-changed-files.outputs.changed-files, '.mm') || + contains(needs.get-changed-files.outputs.changed-files, '.metal') + ) + with: + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cuda-x86_64-67eb930 + script: | + CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" + if [ "$CHANGED_FILES" = "*" ]; then + export ADDITIONAL_LINTRUNNER_ARGS="--take CLANGTIDY,CLANGFORMAT --all-files" + else + export ADDITIONAL_LINTRUNNER_ARGS="--take CLANGTIDY,CLANGFORMAT $CHANGED_FILES" + fi + export CLANG=1 + .github/scripts/lintrunner.sh + + # NOTE: mypy needs its own job because it depends on --all-files, without assessing all files it sometimes + # fails to find types when it should + # NOTE: We should be able to disable this and consolidate with Pyrefly + lintrunner-pyrefly: + uses: ./.github/workflows/_lint.yml + name: lintrunner-pyrefly-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }} + needs: [get-label-type, get-changed-files] + # Only run if there are changed files relevant to pyrefly + if: | + github.repository_owner == 'pytorch' && ( + needs.get-changed-files.outputs.changed-files == '*' || + contains(needs.get-changed-files.outputs.changed-files, '.py') || + contains(needs.get-changed-files.outputs.changed-files, '.pyi') + ) + with: + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-67eb930 + script: | + CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" + echo "Running pyrefly" + ADDITIONAL_LINTRUNNER_ARGS="--take PYREFLY --all-files" .github/scripts/lintrunner.sh + + lintrunner-noclang: + uses: ./.github/workflows/_lint.yml + name: lintrunner-noclang-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }} + needs: [get-label-type, get-changed-files] + with: + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-67eb930 + script: | + CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" + echo "Running all other linters" + if [ "$CHANGED_FILES" = '*' ]; then + ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGTIDY_EXECUTORCH_COMPATIBILITY,CLANGFORMAT,PYREFLY --all-files" .github/scripts/lintrunner.sh + else + ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGTIDY_EXECUTORCH_COMPATIBILITY,CLANGFORMAT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh + fi + + quick-checks: + if: github.repository_owner == 'pytorch' + needs: get-label-type + uses: ./.github/workflows/_lint.yml + with: + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-67eb930 + script: | + # Ensure no non-breaking spaces + # NB: We use 'printf' below rather than '\u000a' since bash pre-4.2 + # does not support the '\u000a' syntax (which is relevant for local linters) + (! git --no-pager grep -In "$(printf '\xC2\xA0')" -- . || (echo "The above lines have non-breaking spaces (U+00A0); please convert them to spaces (U+0020)"; false)) + + # Ensure cross-OS compatible file names + (! git ls-files | grep -E '([<>:"|?*]|[ .]$)' || (echo "The above file names are not valid across all operating systems. Please ensure they don't contain the characters '<>:""|?*' and don't end with a white space or a '.' "; false)) + + # Ensure no versionless Python shebangs + (! git --no-pager grep -In '#!.*python$' -- . || (echo "The above lines have versionless Python shebangs; please specify either python2 or python3"; false)) + + # Ensure ciflow tags mentioned in config + python3 .github/scripts/collect_ciflow_labels.py --validate-tags + + # C++ docs check + pushd docs/cpp/source + ./check-doxygen.sh + popd + + # CUDA kernel launch check + set -eux + python3 torch/testing/_internal/check_kernel_launches.py |& tee cuda_kernel_launch_checks.txt + + pr-sanity-checks: + name: pr-sanity-checks + runs-on: linux.24_04.4x + # Only run this on pull requests. This check is simple enough to be done without a Docker image + if: ${{ github.event_name == 'pull_request' && !contains(github.event.pull_request.labels.*.name, 'skip-pr-sanity-checks') && github.repository_owner == 'pytorch' }} + steps: + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + with: + submodules: false + fetch-depth: -1 + + - name: PR size check (nonretryable) + env: + BASE: ${{ github.event.pull_request.base.sha }} + HEAD: ${{ github.event.pull_request.head.sha }} + run: | + bash .github/scripts/pr-sanity-check.sh + + workflow-checks: + if: github.repository_owner == 'pytorch' + needs: get-label-type + uses: ./.github/workflows/_lint.yml + with: + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-67eb930 + script: | + # Regenerate workflows + .github/scripts/generate_ci_workflows.py + + RC=0 + # Assert that regenerating the workflows didn't change them + if ! .github/scripts/report_git_status.sh .github/workflows; then + echo + echo 'As shown by the above diff, the committed .github/workflows' + echo 'are not up to date according to .github/templates.' + echo 'Please run this command, commit, and push again to your PR:' + echo + echo ' .github/scripts/generate_ci_workflows.py' + echo + echo 'If running that command does nothing, you may need to rebase' + echo 'onto a more recent commit from the PyTorch main branch.' + RC=1 + fi + + # Check that jobs will be cancelled + .github/scripts/ensure_actions_will_cancel.py + + exit $RC + + toc: + if: github.repository_owner == 'pytorch' + needs: get-label-type + uses: ./.github/workflows/_lint.yml + with: + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-67eb930 + script: | + # Regenerate ToCs and check that they didn't change + set -eu + + export PATH=~/.npm-global/bin:"$PATH" + for FILE in $(git grep -Il '' -- '**.md'); do + markdown-toc --bullets='-' -i "$FILE" + done + + if ! .github/scripts/report_git_status.sh .; then + echo + echo 'As shown by the above diff, the table of contents in one or' + echo 'more Markdown files is not up to date with the file contents.' + echo 'You can either apply that Git diff directly to correct the' + echo 'table of contents, or if you have npm installed, you can' + echo 'install the npm package markdown-toc and run the following' + # shellcheck disable=SC2016 + echo 'command (replacing $FILE with the filename for which you want' + echo 'to regenerate the table of contents):' + echo + # shellcheck disable=SC2016 + echo " markdown-toc --bullets='-' -i \"\$FILE\"" + false + fi + + test-tools: + name: Test tools + if: ${{ github.repository == 'pytorch/pytorch' }} + needs: get-label-type + uses: ./.github/workflows/_lint.yml + with: + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-67eb930 + script: | + # Test tools + PYTHONPATH=$(pwd) pytest tools/stats + PYTHONPATH=$(pwd) pytest tools/test -o "python_files=test*.py" + PYTHONPATH=$(pwd) pytest .github/scripts -o "python_files=test*.py" + + test_run_test: + name: Test `run_test.py` is usable without boto3 + if: ${{ github.repository == 'pytorch/pytorch' }} + runs-on: linux.24_04.4x + steps: + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + with: + submodules: false + fetch-depth: 1 + - name: Setup Python 3.10 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: '3.10' + architecture: x64 + cache: pip + - name: Install dependencies + run: | + python3 -m pip install --upgrade pip + pip install pytest-rerunfailures==11.1.* pytest-flakefinder==1.1.* pytest-xdist==3.3.* expecttest==0.3.* fbscribelogger==0.1.* numpy==1.24.* + pip install torch --pre --index-url https://download.pytorch.org/whl/nightly/cpu/ + - name: Run run_test.py (nonretryable) + run: | + # Run test_vulkan, which is a fast noop on Linux + python3 test/run_test.py --include test_vulkan --verbose + + test_collect_env: + if: ${{ github.repository == 'pytorch/pytorch' }} + name: Test collect_env + runs-on: ${{ matrix.runner }} + strategy: + matrix: + include: + - test_type: with_torch + runner: linux.24_04.4x + - test_type: without_torch + runner: linux.24_04.4x + - test_type: older_python_version + runner: linux.24_04.4x + steps: + # [see note: pytorch repo ref] + # deep clone (fetch-depth 0) required, to allow us to use git log + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + with: + submodules: false + fetch-depth: 1 + - name: Get min python version + id: get-min-python-version + if: matrix.test_type == 'older_python_version' + run: | + set -eou pipefail + # Generate PyTorch version to use + echo "MIN_PYTHON_VERSION=$(python3 .github/scripts/get_ci_variable.py --min-python-version)" >> "${GITHUB_OUTPUT}" + - name: Setup Old Python version + if: matrix.test_type == 'older_python_version' + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: 3.8 + architecture: x64 + check-latest: false + cache: pip + cache-dependency-path: | + **/requirements-build.txt + **/requirements.txt + - name: Setup Min Python version + if: matrix.test_type != 'older_python_version' + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: ${{ steps.get-min-python-version.outputs.MIN_PYTHON_VERSION }} + architecture: x64 + check-latest: false + cache: pip + cache-dependency-path: | + **/requirements-build.txt + **/requirements.txt + - name: Install torch + if: matrix.test_type == 'with_torch' + run: | + pip install -r requirements.txt + # Doesn't really matter what torch version, we just need ANY torch installed + pip install 'torch==2.*' + - name: Run collect_env.py (nonretryable) + run: | + # All we need to see is that it passes + python3 torch/utils/collect_env.py + + link-check: + if: github.repository_owner == 'pytorch' + needs: get-label-type + name: Link checks + uses: ./.github/workflows/_link_check_osdc.yml + with: + runner: ${{ needs.get-label-type.outputs.label-type }} + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + + doc-redirects-check: + name: doc-redirects-check + runs-on: linux.24_04.4x + if: github.event_name == 'pull_request' && github.repository_owner == 'pytorch' + steps: + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + with: + submodules: false + fetch-depth: 0 + + - name: Doc redirects check (nonretryable) + env: + BASE_REF: ${{ github.event.pull_request.base.ref }} + run: | + set -euo pipefail + + # Run the check with auto-fix to generate suggestions + python3 .github/scripts/check_doc_redirects.py \ + --base-ref "origin/${BASE_REF}" \ + --auto-fix + + # If redirects.py was modified, show the diff and fail + if ! git diff --quiet docs/source/redirects.py 2>/dev/null; then + echo "" + echo "📋 The following redirects were auto-generated:" + echo "" + git diff docs/source/redirects.py + echo "" + echo "Please add these changes to your PR by running locally:" + echo "" + echo " python3 .github/scripts/check_doc_redirects.py --base-ref origin/main --auto-fix" + echo " git add docs/source/redirects.py" + echo " git commit --amend --no-edit # or create a new commit" + echo "" + exit 1 + fi + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' && github.run_id }} + cancel-in-progress: true diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index c8b4720dfa88c..33006c7724258 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -60,9 +60,6 @@ DLDataType getDLDataType(const Tensor& t) { case ScalarType::ComplexDouble: dtype.code = DLDataTypeCode::kDLComplex; break; - case ScalarType::BComplex32: - TORCH_CHECK_BUFFER(false, "BComplex32 type is not supported by dlpack"); - break; case ScalarType::BFloat16: dtype.code = DLDataTypeCode::kDLBfloat; break; diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 1c2ebad874997..870f7172d1622 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -306,23 +306,10 @@ TORCH_API void record_kernel_function_dtype(std::string name); AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) \ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) -#define AT_DISPATCH_CASE_COMPLEX_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \ - AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) \ - AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ - AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) - #define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__)) -#define AT_DISPATCH_COMPLEX_TYPES_AND2( \ - SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, \ - NAME, \ - AT_DISPATCH_CASE_COMPLEX_TYPES_AND2( \ - SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)) - #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \ AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp index bbed2773318a2..4d12942eb0449 100644 --- a/aten/src/ATen/EmptyTensor.cpp +++ b/aten/src/ATen/EmptyTensor.cpp @@ -53,9 +53,6 @@ inline void raise_warning_for_complex_half(ScalarType dtype) { if (dtype == kComplexHalf) { TORCH_WARN_ONCE( "ComplexHalf support is experimental and many operators don't support it yet."); - } else if (dtype == kBComplex32) { - TORCH_WARN_ONCE( - "BComplex32 support is experimental and many operators don't support it yet."); } } diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp index 123d87b304148..c372ae0ad339f 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -105,12 +105,10 @@ static Tensor unsqueeze_copy_to(const Tensor & self, IntArrayRef dim, c10::SymIn Tensor FunctionalInverses::_fw_primal_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t level) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _fw_primal() during the functionalization pass. For now, this is not supported."); - return Tensor(); } Tensor FunctionalInverses::_make_dual_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode, const at::Tensor& tangent, int64_t level) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _make_dual() during the functionalization pass. For now, this is not supported."); - return Tensor(); } Tensor FunctionalInverses::view_as_real_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { @@ -301,7 +299,6 @@ Tensor FunctionalInverses::transpose_int_inverse(const Tensor& base, const Tenso Tensor FunctionalInverses::_nested_view_from_buffer_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& nested_sizes, const Tensor& nested_strides, const Tensor& storage_offsets) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _nested_view_from_buffer() during the functionalization pass. For now, nested tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional& lengths, int64_t ragged_idx, const std::optional& min_seqlen, const std::optional& max_seqlen) { @@ -342,47 +339,38 @@ Tensor FunctionalInverses::unsqueeze_inverse(const Tensor& base, const Tensor& m Tensor FunctionalInverses::_indices_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::_values_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::indices_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::values_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::_sparse_broadcast_to_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::IntArrayRef size) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _sparse_broadcast_to() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::crow_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call crow_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::col_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call col_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::ccol_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call ccol_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::row_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call row_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::unbind_int_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int64_t dim) { diff --git a/aten/src/ATen/OpMathType.h b/aten/src/ATen/OpMathType.h index 9c8d854ad7294..613d4983c74dd 100644 --- a/aten/src/ATen/OpMathType.h +++ b/aten/src/ATen/OpMathType.h @@ -48,10 +48,6 @@ template <> struct OpMathType> { using type = c10::complex; }; -template <> -struct OpMathType> { - using type = c10::complex; -}; template using opmath_type = typename OpMathType::type; diff --git a/aten/src/ATen/ScalarOps.cpp b/aten/src/ATen/ScalarOps.cpp index 7db13054fddb4..080bb5011cd3f 100644 --- a/aten/src/ATen/ScalarOps.cpp +++ b/aten/src/ATen/ScalarOps.cpp @@ -39,7 +39,7 @@ Tensor& scalar_fill(Tensor& self, const Scalar& value) { AT_DISPATCH_V2( self.scalar_type(), "fill_out", AT_WRAP([&]() { fill_inplace(self, value); - }), kComplexHalf, kBComplex32, kHalf, kBool, kBFloat16, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); return self; } diff --git a/aten/src/ATen/SparseCsrTensorUtils.h b/aten/src/ATen/SparseCsrTensorUtils.h index fd2854374b422..ceb860b117653 100644 --- a/aten/src/ATen/SparseCsrTensorUtils.h +++ b/aten/src/ATen/SparseCsrTensorUtils.h @@ -141,8 +141,8 @@ AT_DISPATCH_SWITCH( \ TYPE, \ NAME, \ - AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \ - kComplexHalf, kBComplex32, kHalf, kBool, kBFloat16, __VA_ARGS__)) + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \ + kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__)) namespace at::sparse_csr { diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp index 682883b32c187..4a1c1525157f3 100644 --- a/aten/src/ATen/Version.cpp +++ b/aten/src/ATen/Version.cpp @@ -218,8 +218,9 @@ std::string get_cxx_flags() { "Buck does not populate the `CXX_FLAGS` field of Caffe2 build options. " "As a result, `get_cxx_flags` is OSS only." ); - #endif + #else return caffe2::GetBuildOptions().at("CXX_FLAGS"); + #endif } } diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index b9e66fea5ebdb..748eecbc1572a 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1956,7 +1956,6 @@ case ScalingType::TensorWise: default: TORCH_CHECK(false); - return -1; } } diff --git a/aten/src/ATen/cuda/CUDADataType.h b/aten/src/ATen/cuda/CUDADataType.h index 4f70f3fbd2353..fba4f855a29b0 100644 --- a/aten/src/ATen/cuda/CUDADataType.h +++ b/aten/src/ATen/cuda/CUDADataType.h @@ -25,9 +25,6 @@ template<> inline cudaDataType getCudaDataType() { template<> inline cudaDataType getCudaDataType>() { return CUDA_C_16F; } -template<> inline cudaDataType getCudaDataType>() { - return CUDA_C_16BF; -} template<> inline cudaDataType getCudaDataType>() { return CUDA_C_32F; } diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index 42ee906065173..29affa2d21ff1 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -43,11 +43,6 @@ constexpr hipDataType HipDataTypeFor() { return HIP_R_16BF; } -template <> -constexpr hipDataType HipDataTypeFor>() { - return HIP_C_16BF; -} - template <> constexpr hipDataType HipDataTypeFor() { return HIP_R_64F; diff --git a/aten/src/ATen/functorch/BatchRulesHelper.h b/aten/src/ATen/functorch/BatchRulesHelper.h index 0d2f075d0c540..f4583ac32a4a0 100644 --- a/aten/src/ATen/functorch/BatchRulesHelper.h +++ b/aten/src/ATen/functorch/BatchRulesHelper.h @@ -141,6 +141,8 @@ void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::S auto arguments = torch::jit::pop(*stack, num_arguments); std::vector>> tensor_inputs; std::vector tensor_pos; + tensor_inputs.reserve(num_arguments); + tensor_pos.reserve(num_arguments); for (const auto idx : c10::irange(0, num_arguments)) { const auto& ivalue = arguments[idx]; if (ivalue.isTensor()) { diff --git a/aten/src/ATen/functorch/BatchedFallback.cpp b/aten/src/ATen/functorch/BatchedFallback.cpp index b479639f1c1a5..ee7b9b69eafb6 100644 --- a/aten/src/ATen/functorch/BatchedFallback.cpp +++ b/aten/src/ATen/functorch/BatchedFallback.cpp @@ -412,10 +412,8 @@ void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::ji return; } - if (isInplaceOp(schema)) { - TORCH_INTERNAL_ASSERT(false, "vmap fallback not supported for in-place ops on nested tensors"); - return; - } + TORCH_INTERNAL_ASSERT(!isInplaceOp(schema), "vmap fallback not supported for in-place ops on nested tensors"); + TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(), "Nested batching rule not implemented for ", schema.operator_name(), "; ", "the fallback path doesn't work on out= or view ops."); diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.cpp b/aten/src/ATen/functorch/BatchedTensorImpl.cpp index 895770fc69921..72af353064661 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.cpp +++ b/aten/src/ATen/functorch/BatchedTensorImpl.cpp @@ -156,7 +156,6 @@ c10::intrusive_ptr BatchedTensorImpl::shallow_copy_and_detach( const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const { TORCH_CHECK(false, "accessing `data` under vmap transform is not allowed"); - return nullptr; } c10::intrusive_ptr BatchedTensorImpl::shallow_copy_and_detach( @@ -164,7 +163,6 @@ c10::intrusive_ptr BatchedTensorImpl::shallow_copy_and_detach( c10::VariableVersion&& version_counter, bool allow_tensor_metadata_change) const { TORCH_CHECK(false, "accessing `data` under vmap transform is not allowed"); - return nullptr; } void BatchedTensorImpl::shallow_copy_from(const c10::intrusive_ptr& impl) { diff --git a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp index 1df4c8938183a..e02f20b102bc7 100644 --- a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp +++ b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp @@ -709,8 +709,10 @@ Tensor nested_cat_batching_rule(const ITensorListRef& tensors, int64_t dim) { // Do a cat for each set of zipped unbound components const auto num_components = unbound.front().size(); std::vector outputs; + outputs.reserve(num_components); for (auto i : c10::irange(num_components)) { std::vector arg_list; + arg_list.reserve(unbound.size()); for (auto j : c10::irange(unbound.size())) { arg_list.push_back(unbound[j][i]); } diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index c7de276f5f88f..164e709536135 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -1201,7 +1201,6 @@ struct Brgemm : public KernelCache { case ScalarType::Float8_e5m2: return f8_support; default: return false; } - return false; } }; @@ -1261,7 +1260,6 @@ struct Pack : public KernelCache { case ScalarType::Float8_e5m2: return fp8_pack; default: return false; } - return false; } }; #endif diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 94e37647a8a5f..51d83ba16779e 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -1774,10 +1774,6 @@ std::tuple convolution_backward_overrideable( IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding, int64_t groups, std::array output_mask) { TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_backward_overrideable: You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function "); - return std::tuple( - at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT), - at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT), - at::empty({})); } static Tensor subvariable(const Tensor& var, int64_t dim, int64_t groups, int64_t g) { diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 2c586a24ae046..0b3ffda30577f 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -60,14 +60,13 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) { #if !defined(C10_MOBILE) #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_V2( \ - TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kBComplex32, kHalf, kBool, kBFloat16, \ + TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, \ AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) #else #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \ - kComplexHalf, kBComplex32, kHalf, \ - kBool, kBFloat16, TYPE, NAME, \ - __VA_ARGS__) + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ + kComplexHalf, kHalf, kBool, kBFloat16, \ + TYPE, NAME, __VA_ARGS__) #endif // special case copy where tensor is contiguous and src is a transposed matrix diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index 090a028786c72..cff843e0ee5c8 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -214,7 +214,6 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( default: TORCH_INTERNAL_ASSERT(false, "An unexpected device type was provided ", device_type); - return ErrorType::DeviceNotSupported; } } @@ -268,7 +267,6 @@ void* DispatchStubImpl::get_call_ptr( case ErrorType::MissingDeviceKernel: TORCH_INTERNAL_ASSERT( false, "DispatchStub: missing kernel for ", device_type); - return nullptr; case ErrorType::DeviceNotSupported: TORCH_CHECK(false, "DispatchStub: unsupported device type", device_type); } diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h index c221394cc2a1c..f0dce20a6eff4 100644 --- a/aten/src/ATen/native/ForeachUtils.h +++ b/aten/src/ATen/native/ForeachUtils.h @@ -217,9 +217,8 @@ inline std::vector convert_tensor_to_scalar_list( "Expected packed scalar Tensor to be of dimension 1. Got ", scalarList_.dim(), " instead."); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( kComplexHalf, - kBComplex32, kHalf, kBool, kBFloat16, diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 257863573d3a8..f608735abfa86 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -351,6 +351,7 @@ inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) { const std::vector a = axes.vec(); const int64_t ndim = self.ndimension(); std::vector perm; + perm.reserve(static_cast(std::max(0, ndim))); for (const auto i : c10::irange(ndim)) { auto it = std::find(a.begin(), a.end(), i); diff --git a/aten/src/ATen/native/LossCTC.cpp b/aten/src/ATen/native/LossCTC.cpp index ef9e63d9fd358..ec4ce8d8550f4 100644 --- a/aten/src/ATen/native/LossCTC.cpp +++ b/aten/src/ATen/native/LossCTC.cpp @@ -534,9 +534,9 @@ Tensor ctc_loss_impl(const Tensor& log_probs_, const Tensor& targets, LengthsTyp target_lengths, BLANK, zero_infinity)); - } - if (zero_infinity) { - res = at::where(res.isinf(), at::zeros({}, res.options()), res); + if (zero_infinity) { + res = at::where(res == Scalar(std::numeric_limits::infinity()), at::zeros({}, res.options()), res); + } } if (reduction == at::Reduction::Mean) { auto target_lengths_t = get_clamped_target_length(target_lengths, res.options()); diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index 1b465790d306c..80c466a6b5815 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -104,8 +104,9 @@ bool use_mkldnn(const Tensor& input, TensorList params, TensorList hx) { (input.scalar_type() == kHalf && !at::GradMode::is_enabled() && mkldnn_fp16_device_check())) && input.numel() != 0; -#endif +#else return false; +#endif } bool use_cudnn(const Tensor& t) { diff --git a/aten/src/ATen/native/Scalar.cpp b/aten/src/ATen/native/Scalar.cpp index 4a8721db6fd44..dea7ecc7118ac 100644 --- a/aten/src/ATen/native/Scalar.cpp +++ b/aten/src/ATen/native/Scalar.cpp @@ -30,7 +30,7 @@ Scalar item(const Tensor& self) { } } -#define AT_SD_BASE_TYPES AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), kComplexHalf, kBComplex32, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) +#define AT_SD_BASE_TYPES AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) #if !defined(C10_MOBILE) #define AT_SD_TYPES AT_EXPAND(AT_SD_BASE_TYPES), AT_EXPAND(AT_FLOAT8_TYPES) #else diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 234bf81872d5c..ecb67c9ef3799 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -2844,9 +2844,8 @@ Tensor count_nonzero_cpu(const Tensor& self, IntArrayRef dims) { const auto num_threads = at::get_num_threads(); DimVector thread_count_nonzero(num_threads); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( kComplexHalf, - kBComplex32, kHalf, kBFloat16, kBool, @@ -2899,9 +2898,8 @@ Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) { DimVector thread_count_nonzero(num_threads + 1); // Pass 1: Count nonzero element per-thread - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( kComplexHalf, - kBComplex32, kHalf, kBFloat16, kBool, @@ -2937,9 +2935,8 @@ Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) { auto out_accessor = result.accessor(); // Pass 2: Write indexes - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( kComplexHalf, - kBComplex32, kHalf, kBFloat16, kBool, diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index 909381ce5ed30..697150b0fbb3a 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -652,7 +652,6 @@ Tensor to_dense_backward( default: TORCH_CHECK( false, "to_dense_backward: Unsupported input layout: ", input_layout); - return Tensor{}; } } @@ -1399,7 +1398,6 @@ Tensor dense_to_sparse_with_mask( " to ", layout_to, " conversion not supported"); - return Tensor{}; } Tensor dense_to_sparse_csr( @@ -1482,7 +1480,6 @@ Tensor dense_to_sparse( " to ", layout_to, " conversion not supported"); - return Tensor{}; } Tensor dense_to_sparse(const Tensor& self, int64_t sparse_dim) { @@ -1766,7 +1763,6 @@ Tensor sparse_compressed_to_sparse_csr( false, "sparse_compressed_to_sparse_csr: expected SparseCsr or SparseCsc layout but got ", self.layout()); - return Tensor{}; } Tensor sparse_compressed_to_sparse_csc( @@ -1787,7 +1783,6 @@ Tensor sparse_compressed_to_sparse_csc( false, "sparse_compressed_to_sparse_csc: expected SparseCsr or SparseCsc layout but got ", self.layout()); - return Tensor{}; } Tensor coo_to_sparse_csr( @@ -2222,7 +2217,6 @@ Tensor sparse_compressed_to_sparse_bsr( false, "sparse_compressed_to_sparse_bsr: expected SparseCsr, SparseCsc, SparseBsr or SparseBsc layout but got ", self.layout()); - return Tensor{}; } Tensor sparse_compressed_to_sparse_bsc( @@ -2260,7 +2254,6 @@ Tensor sparse_compressed_to_sparse_bsc( false, "sparse_compressed_to_sparse_bsc: expected SparseCsr, SparseCsc, SparseBsr or SparseBsc layout but got ", self.layout()); - return Tensor{}; } Tensor sparse_coo_to_sparse(const Tensor& self, const int64_t sparse_dim) { @@ -2273,7 +2266,6 @@ Tensor sparse_coo_to_sparse(const Tensor& self, const int64_t sparse_dim) { " to ", kSparse, " conversion not supported"); - return Tensor{}; } Tensor sparse_compressed_to_sparse( @@ -2377,7 +2369,6 @@ Tensor sparse_compressed_to_sparse( " to ", layout_to, " conversion not supported"); - return Tensor{}; } Tensor sparse_coo_to_sparse( @@ -2414,7 +2405,6 @@ Tensor sparse_coo_to_sparse( " to ", layout_to, " conversion not supported"); - return Tensor{}; } Tensor to_sparse(const Tensor& self, const int64_t sparse_dim) { diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h index e663dc3859d1f..2d0fb908dc726 100644 --- a/aten/src/ATen/native/TensorFactories.h +++ b/aten/src/ATen/native/TensorFactories.h @@ -128,8 +128,7 @@ inline Tensor& fill_empty_deterministic_(Tensor& tensor) { AT_EXPAND(AT_FLOAT8_TYPES), kBFloat16, kHalf, - kComplexHalf, - kBComplex32); + kComplexHalf); } else { AT_DISPATCH_V2( tensor.scalar_type(), diff --git a/aten/src/ATen/native/TestOps.cpp b/aten/src/ATen/native/TestOps.cpp index 90dbf97075093..a1dbde708157b 100644 --- a/aten/src/ATen/native/TestOps.cpp +++ b/aten/src/ATen/native/TestOps.cpp @@ -137,7 +137,6 @@ Tensor FunctionalInverses::_test_autograd_multiple_dispatch_view_inverse(const a TORCH_INTERNAL_ASSERT(false, "Attempted to call _test_autograd_multiple_dispatch_view_inverse() during the functionalization pass. ", "This function is for testing only and should never be called."); - return Tensor(); } } // namespace at::functionalization diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp index 12841ad8e7391..8ecf4a2324074 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp @@ -29,7 +29,6 @@ at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl( false, "Sparse quantized dynamic linear with fused relu is not yet " "supported on qnnpack backend."); - return at::Tensor(); } template <> diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index f9fd934297cd8..24ab7e93b7202 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -626,7 +626,7 @@ void eq_kernel(TensorIteratorBase& iter) { if (iter.dtype() == ScalarType::Bool) { AT_DISPATCH_V2(iter.common_dtype(), "eq_cpu", AT_WRAP([&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a == b; }); - }), kComplexHalf, kBComplex32, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); + }), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); } else { AT_DISPATCH_V2(iter.common_dtype(), "eq_cpu", AT_WRAP([&]() { cpu_kernel_vec( @@ -636,7 +636,7 @@ void eq_kernel(TensorIteratorBase& iter) { }, [](Vectorized a, Vectorized b) -> Vectorized { return a.eq(b); }); - }), kComplexHalf, kHalf, kBComplex32, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); + }), kComplexHalf, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); } } @@ -645,7 +645,7 @@ void ne_kernel(TensorIteratorBase& iter) { if (iter.dtype() == ScalarType::Bool) { AT_DISPATCH_V2(iter.common_dtype(), "ne_cpu", AT_WRAP([&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a != b; }); - }), kComplexHalf, kBComplex32, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); + }), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); } else { AT_DISPATCH_V2(iter.common_dtype(), "ne_cpu", AT_WRAP([&]() { cpu_kernel_vec( @@ -655,7 +655,7 @@ void ne_kernel(TensorIteratorBase& iter) { }, [](Vectorized a, Vectorized b) -> Vectorized { return a.ne(b); }); - }), kComplexHalf, kHalf, kBComplex32, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); + }), kComplexHalf, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); } } diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp index ed0bc9222cfb5..80708e548b196 100644 --- a/aten/src/ATen/native/cpu/CopyKernel.cpp +++ b/aten/src/ATen/native/cpu/CopyKernel.cpp @@ -199,17 +199,17 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne #if !defined(C10_MOBILE) #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \ - kComplexHalf, kBComplex32, kHalf, kBool, \ + kComplexHalf, kHalf, kBool, \ kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \ AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) #define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \ AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \ - kBool, kHalf, kBComplex32, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \ + kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \ AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) #else #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ - ScalarType::ComplexHalf, ScalarType::BComplex32, ScalarType::Half, ScalarType::Bool,ScalarType::BFloat16, \ + ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool,ScalarType::BFloat16, \ TYPE, NAME, __VA_ARGS__) #define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \ @@ -235,8 +235,6 @@ void direct_copy_kernel(TensorIteratorBase &iter) { }); } else if (dtype == ScalarType::ComplexHalf) { cpu_kernel(iter, [=](c10::complex a) -> c10::complex { return a; }); - } else if (dtype == ScalarType::BComplex32) { - cpu_kernel(iter, [=](c10::complex a) -> c10::complex { return a; }); } else if (dtype == ScalarType::Float4_e2m1fn_x2) { cpu_kernel(iter, [=](Float4_e2m1fn_x2 a) -> Float4_e2m1fn_x2 { return a; }); } else if (isBitsType(dtype)) { diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp index ca2d5b1c5080f..822cdad31787c 100644 --- a/aten/src/ATen/native/cpu/IndexKernel.cpp +++ b/aten/src/ATen/native/cpu/IndexKernel.cpp @@ -34,7 +34,6 @@ void index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kComplexHalf, - kBComplex32, kHalf, kBool, kBFloat16); @@ -202,7 +201,6 @@ void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef kFloat8_e4m3fnuz, kFloat8_e5m2fnuz, kComplexHalf, - kBComplex32, kHalf, kBool, kBFloat16); @@ -214,7 +212,7 @@ void index_fill_kernel( int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, kComplexHalf, kBComplex32, + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, kComplexHalf, iter.dtype(), "index_fill_cpu", [&] { auto fill_val = source.to(); auto handle_nonzero_idx_stride = [&](char** data, const int64_t* strides, int64_t n) { @@ -273,7 +271,7 @@ void index_copy_kernel( int64_t dim, int64_t self_dim_size, int64_t self_dim_stride) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, kComplexHalf, kBComplex32, + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, kComplexHalf, iter.dtype(), "index_copy_cpu", [&] { auto handle_nonzero_idx_stride = [&](char** data, const int64_t* strides, int64_t n) { auto* self_data_bytes = data[0]; @@ -348,7 +346,7 @@ void cpu_masked_fill_kernel(TensorIterator& iter, scalar_t value) { } void masked_fill_kernel(TensorIterator& iter, const Scalar& value) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5(kComplexHalf, kBComplex32, kBool, kBFloat16, kHalf, + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kBFloat16, kHalf, iter.dtype(), "masked_fill", [&] { scalar_t scalar_val = value.to(); auto mask_dtype = iter.input_dtype(0); @@ -435,7 +433,6 @@ void masked_select_serial_kernel(TensorIterator& iter, int64_t result_stride) { AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES), kComplexHalf, - kBComplex32, kHalf, kBool, kBFloat16); @@ -479,7 +476,6 @@ void masked_select_kernel(TensorIterator& iter, int64_t result_stride) { AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES), kComplexHalf, - kBComplex32, kHalf, kBool, kBFloat16); diff --git a/aten/src/ATen/native/cpu/UpSampleKernel.cpp b/aten/src/ATen/native/cpu/UpSampleKernel.cpp index 682a8a2b8eff6..39fe91ccc06e4 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernel.cpp +++ b/aten/src/ATen/native/cpu/UpSampleKernel.cpp @@ -24,54 +24,21 @@ namespace { using scale_t = std::vector>; -// Naming conventions used in this file: -// -// - "non_separable": All spatial dimensions are interpolated in a single -// TensorIterator pass, using the recursive InterpolateNonSeparable struct. -// Non-separable is always multi-dimensional (Nd). -// -// - "separable": The multi-dimensional interpolation is decomposed into a -// sequence of 1d passes (one per spatial dimension). The entry point is Nd -// (it loops over dims), and inner functions are 1d. -// -// - "1d": Refers to processing a single spatial dimension within the -// separable approach. Not to be confused with 1d interpolation (e.g. -// linear); it means one dimension of a multi-dimensional separable -// decomposition. -// -// - "Nd": Templated on out_ndims (1, 2, or 3) for the separable approach. - - -// ---- Non-separable interpolation ---- -// -// Used by: nearest, linear, bilinear (float), cubic (float), trilinear. -// Processes all spatial dims in a single TensorIterator pass via the recursive -// InterpolateNonSeparable struct. -// -// Call chain: -// upsample_non_separable_Nd_kernel_impl -// -> upsample_non_separable -// -> basic_loop_non_separable -// -> interpolate_non_separable -// -> InterpolateNonSeparable (recursive struct) -// -// Helper structs and methods for upsample_non_separable -// // Interpolation structure to compute output value in n-dimensional case. // - recursively compute interpolated output for each dimension // - we rely a lot on compiler's code optimization such that implemented operations // can be automatically factorized and vectorized using SSE and AVX2 template -struct InterpolateNonSeparable { +struct Interpolate { static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) { index_t ids = *(index_t*)&data[0][i * strides[0]]; opmath_t wts = *(scalar_t*)&data[1][i * strides[1]]; - opmath_t t = InterpolateNonSeparable::eval(src + ids, &data[2 * interp_size], &strides[2 * interp_size], i); + opmath_t t = Interpolate::eval(src + ids, &data[2 * interp_size], &strides[2 * interp_size], i); opmath_t output = t * wts; for (const auto j : c10::irange(1, interp_size)) { ids = *(index_t*)&data[2 * j + 0][i * strides[2 * j + 0]]; wts = *(scalar_t*)&data[2 * j + 1][i * strides[2 * j + 1]]; - t = InterpolateNonSeparable::eval(src + ids, &data[2 * interp_size], &strides[2 * interp_size], i); + t = Interpolate::eval(src + ids, &data[2 * interp_size], &strides[2 * interp_size], i); output += t * wts; } return output; @@ -79,7 +46,7 @@ struct InterpolateNonSeparable { }; template -struct InterpolateNonSeparable<1, scalar_t, opmath_t, index_t, interp_size> { +struct Interpolate<1, scalar_t, opmath_t, index_t, interp_size> { static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) { index_t ids = *(index_t*)&data[0][i * strides[0]]; opmath_t wts = *(scalar_t*)&data[1][i * strides[1]]; @@ -96,15 +63,15 @@ struct InterpolateNonSeparable<1, scalar_t, opmath_t, index_t, interp_size> { }; template -struct InterpolateNonSeparable { +struct Interpolate { static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) { index_t ids = *(index_t*)&data[0][i * strides[0]]; - return InterpolateNonSeparable::eval(src + ids, &data[2], &strides[2], i); + return Interpolate::eval(src + ids, &data[2], &strides[2], i); } }; template -struct InterpolateNonSeparable<1, scalar_t, opmath_t, index_t, 1> { +struct Interpolate<1, scalar_t, opmath_t, index_t, 1> { static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) { index_t ids = *(index_t*)&data[0][i * strides[0]]; return *(scalar_t *)&src[ids]; @@ -114,25 +81,25 @@ struct InterpolateNonSeparable<1, scalar_t, opmath_t, index_t, 1> { // There is an unexpected 2x slowdown for upsample_trilinear3d channels_first // for both 1 and 6 threads. We have to specialize this case as below: // Once the issue is fixed we can keep generic implementation and remove: -// struct InterpolateNonSeparable and -// struct InterpolateNonSeparable<1, scalar_t, index_t, 2> +// struct Interpolate and +// struct Interpolate<1, scalar_t, index_t, 2> template -struct InterpolateNonSeparable { +struct Interpolate { static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) { index_t i0 = *(index_t*)&data[0][i * strides[0]]; index_t i1 = *(index_t*)&data[2][i * strides[2]]; opmath_t w0 = *(scalar_t *)&data[1][i * strides[1]]; opmath_t w1 = *(scalar_t *)&data[3][i * strides[3]]; - opmath_t t0 = InterpolateNonSeparable::eval(src + i0, &data[4], &strides[4], i); - opmath_t t1 = InterpolateNonSeparable::eval(src + i1, &data[4], &strides[4], i); + opmath_t t0 = Interpolate::eval(src + i0, &data[4], &strides[4], i); + opmath_t t1 = Interpolate::eval(src + i1, &data[4], &strides[4], i); return t0 * w0 + t1 * w1; } }; template -struct InterpolateNonSeparable<1, scalar_t, opmath_t, index_t, 2> { +struct Interpolate<1, scalar_t, opmath_t, index_t, 2> { static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) { index_t i0 = *(index_t*)&data[0][i * strides[0]]; index_t i1 = *(index_t*)&data[2][i * strides[2]]; @@ -145,28 +112,13 @@ struct InterpolateNonSeparable<1, scalar_t, opmath_t, index_t, 2> { }; template -inline scalar_t interpolate_non_separable(char* src, char** data, const int64_t* strides, int64_t i) { +inline scalar_t interpolate(char* src, char** data, const int64_t* strides, int64_t i) { using opmath_t = at::opmath_type; - return InterpolateNonSeparable::eval(src, data, strides, i); + return Interpolate::eval(src, data, strides, i); } -// ---- Separable interpolation ---- -// -// Used by: bilinear/bicubic with antialias=True, and bilinear/bicubic uint8 -// (as fallback when AVX isn't supported). -// Processes one spatial dimension at a time. The outer loop over dimensions -// is in upsample_separable_Nd_kernel_impl. -// -// Call chain: -// upsample_separable_Nd_kernel_impl (loops over dims) -// -> upsample_separable_1d -// -> basic_loop_separable_1d_horizontal (for last spatial dim, i.e. W) -// -> interpolate_separable_1d -// -> basic_loop_separable_1d_vertical (for other spatial dims, e.g. H, D) -// -> interpolate_separable_1d_zero_strides - template -inline scalar_t interpolate_separable_1d_zero_strides( +inline scalar_t interpolate_aa_single_dim_zero_strides( char* src, char** data, const index_t ids_stride) { @@ -190,7 +142,7 @@ inline scalar_t interpolate_separable_1d_zero_strides( } template -inline scalar_t interpolate_separable_1d( +inline scalar_t interpolate_aa_single_dim( char* src, char** data, const int64_t* strides, @@ -263,7 +215,7 @@ inline bool is_contiguous_stride(const int64_t* strides) { // strides=(0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0) // // Using these methods we can hint the compiler to factorize constant indices and weights -// in upsample_non_separable +// in cpu_upsample_linear method template struct CheckAlmostAllZeroStrides { static inline bool eval(const int64_t* strides) { @@ -295,17 +247,17 @@ inline bool check_almost_all_zero_stride(const int64_t* strides) { // Helper method to compute interpolation for nearest, linear, cubic modes template -inline void basic_loop_non_separable(char** data, const int64_t* strides, int64_t n) { +inline void basic_loop(char** data, const int64_t* strides, int64_t n) { char* dst = data[0]; char* src = data[1]; for (const auto i : c10::irange(n)) { - *(scalar_t*)&dst[i * strides[0]] = interpolate_non_separable( + *(scalar_t*)&dst[i * strides[0]] = interpolate( src + i * strides[1], &data[2], &strides[2], i); } } template -inline void basic_loop_separable_1d_vertical( +inline void basic_loop_aa_vertical( char** data, const int64_t* strides, int64_t n, @@ -317,13 +269,13 @@ inline void basic_loop_separable_1d_vertical( for (const auto i : c10::irange(n)) { *(scalar_t*)&dst[i * strides[0]] = - interpolate_separable_1d_zero_strides( + interpolate_aa_single_dim_zero_strides( src + i * strides[1], &data[2], ids_stride); } } template <> -inline void basic_loop_separable_1d_vertical( +inline void basic_loop_aa_vertical( char** data, const int64_t* strides, int64_t n, @@ -361,7 +313,7 @@ inline void basic_loop_separable_1d_vertical( } template -inline void basic_loop_separable_1d_horizontal( +inline void basic_loop_aa_horizontal( char** data, const int64_t* strides, int64_t n, @@ -374,20 +326,20 @@ inline void basic_loop_separable_1d_horizontal( if (strides[1] == 0) { for (const auto i : c10::irange(n)) { *(scalar_t*)&dst[i * strides[0]] = - interpolate_separable_1d( + interpolate_aa_single_dim( src, &data[2], &strides[2], i, ids_stride); } } else { for (const auto i : c10::irange(n)) { *(scalar_t*)&dst[i * strides[0]] = - interpolate_separable_1d( + interpolate_aa_single_dim( src + i * strides[1], &data[2], &strides[2], i, ids_stride); } } } template <> -inline void basic_loop_separable_1d_horizontal( +inline void basic_loop_aa_horizontal( char** data, const int64_t* strides, int64_t n, @@ -438,10 +390,10 @@ inline void basic_loop_separable_1d_horizontal( // output_DN[a] = interpolate(input_DN[a], w_DN[a], input_DN[a+1], w_DN[a+1], ...) // and i - dimension index and a - linear index for spatial coordinates // -// The recursive call is implemented with the InterpolateNonSeparable struct using template for +// The recursive call is implemented with InterpLinear struct using template for // the loop unrolling on compile time. template -void upsample_non_separable(at::TensorIterator& iter) +void cpu_upsample_generic(at::TensorIterator& iter) { auto loop = [&](char** data, const int64_t* strides, int64_t n) { // special-cases to let the compiler apply compile-time input-specific optimizations @@ -449,21 +401,21 @@ void upsample_non_separable(at::TensorIterator& iter) // NOLINTNEXTLINE(bugprone-branch-clone) check_almost_all_zero_stride(&strides[2]))) { // contiguous channels-first case - basic_loop_non_separable(data, strides, n); + basic_loop(data, strides, n); } else if ((strides[0] == sizeof(scalar_t) && (strides[1] == sizeof(scalar_t)) && check_almost_all_zero_stride(&strides[2]))) { // contiguous channels-last case - basic_loop_non_separable(data, strides, n); + basic_loop(data, strides, n); } else { // fallback - basic_loop_non_separable(data, strides, n); + basic_loop(data, strides, n); } }; iter.for_each(loop); } template -void upsample_nearest_channels_last( +void cpu_upsample_nearest_channels_last( const Tensor& output_, const Tensor& input_, const scale_type& scales) { @@ -568,7 +520,7 @@ inline VecType interpolate(const scalar_t* t, accscalar_t w, Args... a } template -void upsample_linear_channels_last( +void cpu_upsample_linear_channels_last( const Tensor& output_, const Tensor& input_, bool align_corners, @@ -728,7 +680,7 @@ void upsample_linear_channels_last( } } -// Helper structs to use with upsample_non_separable_Nd_kernel_impl +// Helper structs to use with upsample_generic_Nd_kernel_impl struct HelperInterpBase { static inline void init_indices_weights( @@ -794,7 +746,7 @@ struct HelperInterpBase { // for interpolation with antialiasing=false mode. It returns the maximal weights value. // This function is templated with scalar_t for type of scale and weights but is only used for // bilinear/bicubic modes on uint8 input and antialiasing=false (in this case scalar_t is double). - // For float input types we are using upsample_non_separable_Nd_kernel_impl and compute_indices_weights methods + // For float input types we are using upsample_generic_Nd_kernel_impl and compute_indices_weights methods template static inline scalar_t _compute_indices_min_size_weights( const int64_t i, const int64_t input_size, const scalar_t scale, @@ -954,7 +906,7 @@ struct HelperInterpBase { weights as double, but then convert them to int16 via some conversion logic detailed below. This allows us to compute all interpolation operation (sum of multiplications) as ints instead of floats. The result is converted back into - uint8 in basic_loop_separable_1d_horizontal (and vertical) + uint8 in basic_loop_aa_horizontal (and vertical) In essence the idea is to avoid a multiplication between a float (the weight) and an int (the pixel value) and instead run a multiplication between @@ -1425,7 +1377,7 @@ struct HelperInterpCubic : public HelperInterpBase { // - scale_type is template type for scales, typically std::optional // - template class F is one of the above structs to compute indices and weights template -void upsample_non_separable_Nd_kernel_impl( +void upsample_generic_Nd_kernel_impl( const Tensor& output, const Tensor& input, bool align_corners, @@ -1486,22 +1438,42 @@ void upsample_non_separable_Nd_kernel_impl( if (interp_size > 1) { // Nearest also supports uint8 tensor, so need to handle it separately AT_DISPATCH_FLOATING_TYPES_AND2( - kBFloat16, kHalf, iter.dtype(), "upsample_non_separable", [&] { + kBFloat16, kHalf, iter.dtype(), "upsample_generic_Nd", [&] { // MSVC can not catch constexpr int interp_size here constexpr int mode = F::interp_size; - upsample_non_separable(iter); + cpu_upsample_generic(iter); }); } else { AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf, - iter.dtype(), "upsample_non_separable", [&] { + iter.dtype(), "upsample_generic_Nd", [&] { constexpr int mode = F::interp_size; - upsample_non_separable(iter); + cpu_upsample_generic(iter); }); } } +template +void cpu_upsample_generic_aa(at::TensorIterator& iter, unsigned int weights_precision) { + + auto loop = [&](char** data, const int64_t* strides, int64_t n) { + if constexpr (is_horizontal) { + + // Strides are : X 0 | 8 8 8 0 8 (Channels first) + // Strides are : X X | 0 0 0 0 0 (Channels last) + basic_loop_aa_horizontal(data, strides, n, weights_precision); + } else { + // Strides are : X Y | 0 0 0 0 0 (Channels first) + // Strides are : X X | 0 0 0 0 0 (Channels last) + // upsampling data between contiguous dimensions (aka vertical resampling) + basic_loop_aa_vertical(data, strides, n, weights_precision); + } + }; + + iter.for_each(loop); +} + template -void upsample_separable_1d( +void _separable_upsample_generic_Nd_kernel_impl_single_dim( const Tensor& output, const Tensor& input, int interp_dim, @@ -1562,19 +1534,8 @@ void upsample_separable_1d( auto iter = config.build(); AT_DISPATCH_FLOATING_TYPES_AND( - at::ScalarType::Byte, iter.dtype(), "upsample_separable_1d", [&] { - auto loop = [&](char** data, const int64_t* strides, int64_t n) { - if constexpr (is_horizontal) { - // Strides are : X 0 | 8 8 8 0 8 (Channels first) - // Strides are : X X | 0 0 0 0 0 (Channels last) - basic_loop_separable_1d_horizontal(data, strides, n, weights_precision); - } else { - // Strides are : X Y | 0 0 0 0 0 (Channels first) - // Strides are : X X | 0 0 0 0 0 (Channels last) - basic_loop_separable_1d_vertical(data, strides, n, weights_precision); - } - }; - iter.for_each(loop); + at::ScalarType::Byte, iter.dtype(), "upsample_generic_Nd_aa", [&] { + cpu_upsample_generic_aa(iter, weights_precision); }); } @@ -1583,7 +1544,7 @@ void upsample_separable_1d( // (dtype == uint8 and mode in ("bilinear", "bicubic")): this is used as // fallback in these settings when AVX isn't supported. template -void upsample_separable_Nd_kernel_impl( +void separable_upsample_generic_Nd_kernel_impl( const Tensor& output, const Tensor& input, bool align_corners, @@ -1602,29 +1563,29 @@ void upsample_separable_Nd_kernel_impl( at::Tensor temp_output, temp_input = input; int interp_dim = 0; - // Precompute the number of 1d resize ops + // Precompute the number of single dim resize method invocations // to avoid copying temporary buffer to output - int num_1d_ops = 0; + int num_single_dim_ops = 0; for (const auto i : c10::irange(out_ndims)) { interp_dim = 2 + out_ndims - 1 - i; if (output_shape[interp_dim] != input_shape[interp_dim]) { - num_1d_ops += 1; + num_single_dim_ops += 1; } } - // Horizontal resampling (last spatial dim, i.e. W) + // upsampling data within the contiguous dimension (aka horizontal resampling) interp_dim = 2 + out_ndims - 1; if (output_shape[interp_dim] != input_shape[interp_dim]) { - num_1d_ops -= 1; - if (num_1d_ops > 0) { + num_single_dim_ops -= 1; + if (num_single_dim_ops > 0) { temp_oshape[interp_dim] = output_shape[interp_dim]; temp_output = at::empty(temp_oshape, input.options()); } else { temp_output = output; } - upsample_separable_1d< + _separable_upsample_generic_Nd_kernel_impl_single_dim< out_ndims, scale_t, F, @@ -1633,20 +1594,20 @@ void upsample_separable_Nd_kernel_impl( temp_input = temp_output; } - // Vertical resampling (remaining spatial dims, e.g. H, D) + // upsampling data between contiguous dimensions (aka vertical resampling) for (const auto i : c10::irange(1, out_ndims)) { interp_dim = 2 + out_ndims - 1 - i; if (output_shape[interp_dim] != input_shape[interp_dim]) { - num_1d_ops -= 1; - if (num_1d_ops > 0) { + num_single_dim_ops -= 1; + if (num_single_dim_ops > 0) { temp_oshape[interp_dim] = output_shape[interp_dim]; temp_output = at::empty(temp_oshape, input.options()); } else { temp_output = output; } - upsample_separable_1d< + _separable_upsample_generic_Nd_kernel_impl_single_dim< out_ndims, scale_t, F, @@ -1661,7 +1622,7 @@ void upsample_nearest1d_kernel_impl( const Tensor& output, const Tensor& input, std::optional scales_w) { - upsample_non_separable_Nd_kernel_impl<1, scale_t, HelperInterpNearest>( + upsample_generic_Nd_kernel_impl<1, scale_t, HelperInterpNearest>( output, input, false, {scales_w}); } @@ -1669,28 +1630,26 @@ void _upsample_nearest_exact1d_kernel_impl( const Tensor& output, const Tensor& input, std::optional scales_w) { - upsample_non_separable_Nd_kernel_impl<1, scale_t, HelperInterpNearestExact>( + upsample_generic_Nd_kernel_impl<1, scale_t, HelperInterpNearestExact>( output, input, false, {scales_w}); } -int _use_channels_last_kernel_2d( +int _use_vectorized_kernel_cond_2d( const Tensor& output, const Tensor& input) { - // This condition is used to know whether we should dispatch to a - // channels-last-optimized kernel, or to the more general - // upsample_non_separable_Nd_kernel_impl(). For now, the channels-last kernels - // are only optimized for channels_last and when C >= 4 (shape = NCHW). - // For a very wide range of use-cases (typically image or mask resizing - // where we have C < 4), using upsample_non_separable_Nd_kernel_impl() is - // actually faster. On top of that, benchmarks showed that this also - // depends on the *output* size (output_H + output_W), for both - // upsampling and downsampling. The current 128 threshold was determined - // through benchmarks. + // This condition is used to know whether we should dispatch to a vectorized + // kernel, or to the more general upsample_generic_Nd_kernel_impl(). For now, + // the vectorized kernels are only optimized for channels_last and when C >= 4 + // (shape = NCHW). For a very wide range of use-cases (typically image or mask + // resizing where we have C < 4), using upsample_generic_Nd_kernel_impl() is + // actually faster. On top of that, benchmarks showed that this also depends on + // the *output* size (output_H + output_W), for both upsampling and + // downsampling. The current 128 threshold was determined through benchmarks. return ((input.is_contiguous(at::MemoryFormat::ChannelsLast)) && (input.size(1) > 3)) || ((output.size(-2) + output.size(-1)) <= 128); } -int _use_channels_last_kernel_3d( - // Similar to _use_channels_last_kernel_2d() but for 3d resampling (e.g. videos) +int _use_vectorized_kernel_cond_3d( + // Similar to _use_vectorized_kernel_cond_2d() but for 3d resampling (e.g. videos) // Note that unlike the 2d case, this is not subject to small output size // overhead - hence the absence of the 128 threshold in the condition. const Tensor& output, @@ -1704,13 +1663,13 @@ void upsample_nearest2d_kernel_impl( const Tensor& input, std::optional scales_h, std::optional scales_w) { - if (_use_channels_last_kernel_2d(output, input)) { + if (_use_vectorized_kernel_cond_2d(output, input)) { AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf, input.scalar_type(), "upsample_nearest2d_channels_last", [&] { - upsample_nearest_channels_last(output, input, {scales_h, scales_w}); + cpu_upsample_nearest_channels_last(output, input, {scales_h, scales_w}); }); } else { - upsample_non_separable_Nd_kernel_impl<2, scale_t, HelperInterpNearest>( + upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpNearest>( output, input, false, {scales_h, scales_w}); } } @@ -1720,12 +1679,12 @@ void _upsample_nearest_exact2d_kernel_impl( const Tensor& input, std::optional scales_h, std::optional scales_w) { - if (_use_channels_last_kernel_2d(output, input)) { + if (_use_vectorized_kernel_cond_2d(output, input)) { AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf, input.scalar_type(), "upsample_nearest2d_channels_last", [&] { - upsample_nearest_channels_last(output, input, {scales_h, scales_w}); + cpu_upsample_nearest_channels_last(output, input, {scales_h, scales_w}); }); } else { - upsample_non_separable_Nd_kernel_impl<2, scale_t, HelperInterpNearestExact>( + upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpNearestExact>( output, input, false, {scales_h, scales_w}); } } @@ -1736,13 +1695,13 @@ void upsample_nearest3d_kernel_impl( std::optional scales_d, std::optional scales_h, std::optional scales_w) { - if (_use_channels_last_kernel_3d(output, input)) { + if (_use_vectorized_kernel_cond_3d(output, input)) { AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf, input.scalar_type(), "upsample_nearest3d_channels_last", [&] { - upsample_nearest_channels_last(output, input, {scales_d, scales_h, scales_w}); + cpu_upsample_nearest_channels_last(output, input, {scales_d, scales_h, scales_w}); }); } else { - upsample_non_separable_Nd_kernel_impl<3, scale_t, HelperInterpNearest>( + upsample_generic_Nd_kernel_impl<3, scale_t, HelperInterpNearest>( output, input, false, {scales_d, scales_h, scales_w}); } } @@ -1753,12 +1712,12 @@ void _upsample_nearest_exact3d_kernel_impl( std::optional scales_d, std::optional scales_h, std::optional scales_w) { - if (_use_channels_last_kernel_3d(output, input)) { + if (_use_vectorized_kernel_cond_3d(output, input)) { AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf, input.scalar_type(), "upsample_nearest3d_channels_last", [&] { - upsample_nearest_channels_last(output, input, {scales_d, scales_h, scales_w}); + cpu_upsample_nearest_channels_last(output, input, {scales_d, scales_h, scales_w}); }); } else { - upsample_non_separable_Nd_kernel_impl<3, scale_t, HelperInterpNearestExact>( + upsample_generic_Nd_kernel_impl<3, scale_t, HelperInterpNearestExact>( output, input, false, {scales_d, scales_h, scales_w}); } } @@ -1768,7 +1727,7 @@ void upsample_linear1d_kernel_impl( const Tensor& input, bool align_corners, std::optional scales_w) { - upsample_non_separable_Nd_kernel_impl<1, scale_t, HelperInterpLinear>( + upsample_generic_Nd_kernel_impl<1, scale_t, HelperInterpLinear>( output, input, align_corners, {scales_w}); } @@ -1780,17 +1739,17 @@ void upsample_bilinear2d_kernel_impl_float( std::optional scales_h, std::optional scales_w) { - // See note above about _use_channels_last_kernel_2d(output, input). The extra cond is present + // See note above about _use_vectorized_kernel_cond_2d(output, input). The extra cond is present // because benchmarks showed that with only 1 thread, images (C == 3) were - // slightly faster with the channels-last kernel than with the generic one. + // slightly faster with the vectorized kernel than with the generic one. // That's not the case for masks though (C == 1), which strongly benefit from // using the generic kernel. - if ((_use_channels_last_kernel_2d(output, input)) || (at::get_num_threads() == 1 && input.size(1) == 3)) { + if ((_use_vectorized_kernel_cond_2d(output, input)) || (at::get_num_threads() == 1 && input.size(1) == 3)) { AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "upsample_bilinear2d_channels_last", [&] { - upsample_linear_channels_last(output, input, align_corners, {scales_h, scales_w}); + cpu_upsample_linear_channels_last(output, input, align_corners, {scales_h, scales_w}); }); } else { - upsample_non_separable_Nd_kernel_impl<2, scale_t, HelperInterpLinear>( + upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>( output, input, align_corners, {scales_h, scales_w}); } } @@ -1818,7 +1777,7 @@ void upsample_bilinear2d_kernel_impl( /*antialias=*/false); } #endif // CPU_CAPABILITY_AVX2 - return upsample_separable_Nd_kernel_impl<2, scale_t, HelperInterpLinear>( + return separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>( output, input, align_corners, {scales_h, scales_w}, /*antialias=*/false); } @@ -1849,7 +1808,7 @@ void upsample_bilinear2d_aa_kernel_impl( } #endif // CPU_CAPABILITY_AVX2 } - return upsample_separable_Nd_kernel_impl<2, scale_t, HelperInterpLinear>( + return separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>( output, input, align_corners, {scales_h, scales_w}, /*antialias=*/true); } @@ -1861,12 +1820,12 @@ void upsample_trilinear3d_kernel_impl( std::optional scales_d, std::optional scales_h, std::optional scales_w) { - if ((_use_channels_last_kernel_3d(output, input))) { + if ((_use_vectorized_kernel_cond_3d(output, input))) { AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "upsample_trilinear3d_channels_last", [&] { - upsample_linear_channels_last(output, input, align_corners, {scales_d, scales_h, scales_w}); + cpu_upsample_linear_channels_last(output, input, align_corners, {scales_d, scales_h, scales_w}); }); } else { - upsample_non_separable_Nd_kernel_impl<3, scale_t, HelperInterpLinear>( + upsample_generic_Nd_kernel_impl<3, scale_t, HelperInterpLinear>( output, input, align_corners, {scales_d, scales_h, scales_w}); } } @@ -1894,11 +1853,11 @@ void upsample_bicubic2d_kernel_impl( /*antialias=*/false); } #endif // CPU_CAPABILITY_AVX2 - return upsample_separable_Nd_kernel_impl<2, scale_t, HelperInterpCubic>( + return separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>( output, input, align_corners, {scales_h, scales_w}, /*antialias=*/false); } - return upsample_non_separable_Nd_kernel_impl<2, scale_t, HelperInterpCubic>( + return upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>( output, input, align_corners, {scales_h, scales_w}); } @@ -1926,7 +1885,7 @@ void upsample_bicubic2d_aa_kernel_impl( } #endif // CPU_CAPABILITY_AVX2 } - return upsample_separable_Nd_kernel_impl<2, scale_t, HelperInterpCubic>( + return separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>( output, input, align_corners, {scales_h, scales_w}, /*antialias=*/true); } @@ -1935,7 +1894,7 @@ template < typename scalar_t, typename scale_type, class F> -void upsample_separable_Nd_backward_aa( +void cpu_upsample_genNd_backward_aa( const Tensor& grad_input_, const Tensor& grad_output_, bool align_corners, @@ -2053,7 +2012,7 @@ void upsample_bilinear2d_aa_backward_kernel_impl( std::optional scales_w) { AT_DISPATCH_FLOATING_TYPES( grad_output.scalar_type(), "upsample_bilinear2d_aa_backward_cpu", [&] { - upsample_separable_Nd_backward_aa( + cpu_upsample_genNd_backward_aa( grad_input, grad_output, align_corners, {scales_h, scales_w}); }); } @@ -2066,7 +2025,7 @@ void upsample_bicubic2d_aa_backward_kernel_impl( std::optional scales_w) { AT_DISPATCH_FLOATING_TYPES( grad_output.scalar_type(), "upsample_bicubic2d_aa_backward_cpu", [&] { - upsample_separable_Nd_backward_aa( + cpu_upsample_genNd_backward_aa( grad_input, grad_output, align_corners, {scales_h, scales_w}); }); } diff --git a/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h b/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h index 3c351b05dc92f..c1bf79dfa44e6 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h +++ b/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h @@ -176,7 +176,7 @@ void ImagingResampleHorizontal( // // TODO: we may want to merge that into the fallback code (currently called - // basic_loop_separable_1d_horizontal) + // basic_loop_aa_horizontal) // Although this may not be needed if / when we port all this code to use // Vec.h since this would potentially give us another fall-back implem @@ -252,7 +252,7 @@ void ImagingResampleVertical( // oB[xoffset + i] = b[xoffset + ymin[i]] * w[i, 0] + ... + b[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1] // TODO: we may want to merge that into the fallback code (currently called - // basic_loop_separable_1d_vertical) + // basic_loop_aa_vertical) // Although this may not be needed if / when we port all this code to use // Vec.h since this would potentially give us another fall-back implem const int16_t* kk = (int16_t*)(vert_indices_weights[3].const_data_ptr()); @@ -289,7 +289,7 @@ void ImagingResampleVertical( // mode for uint8 dtype when C <= 4, with or without antialias. The // implem is based on PIL-SIMD. // Its equivalent implementation (fallback) for when AVX isn't supported or when -// C > 4 is upsample_separable_Nd_kernel_impl() There are a bunch of +// C > 4 is separable_upsample_generic_Nd_kernel_impl() There are a bunch of // future improvement that can be done: look for the TODOs in this file. // For details on how the weights are computed and how the multiplications are // run on int (instead of float weights), see diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index e427ff9a50da0..d87395bb480b3 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -108,8 +108,7 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa case Activation::GELU: return cuda::blas::GEMMAndBiasActivationEpilogue::GELU; default: - TORCH_CHECK(false); - return cuda::blas::GEMMAndBiasActivationEpilogue::None; + TORCH_CHECK(false, "Unknown activation epologue type"); } } @@ -228,9 +227,6 @@ static bool isInputCompliesAddmmCudaLt( mat2_sizes[0] > 1 && mat2_sizes[1] > 1 ) ); - - // no compliance by default - return false; } template diff --git a/aten/src/ATen/native/cuda/CUDAScalar.cu b/aten/src/ATen/native/cuda/CUDAScalar.cu index 02c45bee9d98d..169a2ab92615f 100644 --- a/aten/src/ATen/native/cuda/CUDAScalar.cu +++ b/aten/src/ATen/native/cuda/CUDAScalar.cu @@ -31,7 +31,7 @@ Scalar _local_scalar_dense_cuda(const Tensor& self) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); at::cuda::memcpy_and_sync(value.mutable_data_ptr(), self.const_data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream); r = Scalar(*value.const_data_ptr()); - }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBComplex32, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); return r; } diff --git a/aten/src/ATen/native/cuda/CompareEQKernel.cu b/aten/src/ATen/native/cuda/CompareEQKernel.cu index 24602c2956a27..442e484b9fa5c 100644 --- a/aten/src/ATen/native/cuda/CompareEQKernel.cu +++ b/aten/src/ATen/native/cuda/CompareEQKernel.cu @@ -33,7 +33,7 @@ C10_NOINLINE void compare_eq_ne_kernel(TensorIteratorBase &iter, EqOpType op) { AT_DISPATCH_V2(iter.common_dtype(), "compare_eq_ne_cuda", AT_WRAP([&]() { opmath_symmetric_gpu_kernel_with_scalars( iter, CompareEqFunctor(op)); - }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBComplex32, kHalf, kBFloat16, kBool, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); } void eq_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu index 9094339cac83f..3a9abc0588648 100644 --- a/aten/src/ATen/native/cuda/Copy.cu +++ b/aten/src/ATen/native/cuda/Copy.cu @@ -21,8 +21,7 @@ #include #include -// TODO(NS): Investigate why FP8 conversion intrinsics end up being slower -#ifdef AT_USE_NV_CVT_INTRINSICS +#if defined(CUDA_VERSION) && CUDA_VERSION >= 13000 #include #endif @@ -69,25 +68,53 @@ void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) { } #endif +template +struct ConvertToFloat8E4M3fnOp { + __device__ __forceinline__ Float8_e4m3fn operator()(SrcT value) const { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 13000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 + __nv_fp8_storage_t x; + if constexpr (std::is_same_v) { + x = __nv_cvt_float_to_fp8(value, __NV_SATFINITE, __NV_E4M3); + } else if constexpr (std::is_same_v) { + x = __nv_cvt_halfraw_to_fp8(static_cast<__half>(value), __NV_SATFINITE, __NV_E4M3); + } else if constexpr (std::is_same_v) { + x = __nv_cvt_bfloat16raw_to_fp8(static_cast<__nv_bfloat16>(value), __NV_SATFINITE, __NV_E4M3); + } else { + x = __nv_cvt_float_to_fp8(static_cast(value), __NV_SATFINITE, __NV_E4M3); + } + return Float8_e4m3fn(x, Float8_e4m3fn::from_bits()); +#else + return Float8_e4m3fn(value); +#endif + } +}; + +// e5m2 intrinsics are correct but slower; only used for float on Blackwell +// to work around the ptxas subnormal codegen bug. +struct ConvertFloatToFloat8E5M2Op { + __device__ __forceinline__ Float8_e5m2 operator()(float value) const { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 13020 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 + auto x = __nv_cvt_float_to_fp8(value, __NV_NOSAT, __NV_E5M2); + return Float8_e5m2(x, Float8_e5m2::from_bits()); +#else + return Float8_e5m2(value); +#endif + } +}; + void float8_copy_kernel_cuda(TensorIteratorBase &iter) { ScalarType dtype = iter.dtype(0); ScalarType other_dtype = iter.dtype(1); if (dtype == kFloat8_e4m3fn) { switch (other_dtype) { case kFloat: - gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { - return Float8_e4m3fn(value); - }); + gpu_kernel_nocast(iter, ConvertToFloat8E4M3fnOp{}); break; case kHalf: - gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { - return Float8_e4m3fn(value); - }); + gpu_kernel_nocast(iter, ConvertToFloat8E4M3fnOp{}); break; case kBFloat16: - gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { - return Float8_e4m3fn(value); - }); + gpu_kernel_nocast(iter, ConvertToFloat8E4M3fnOp{}); break; default: gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fn x) { return x; }); @@ -96,33 +123,16 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) { } else if (dtype == kFloat8_e5m2) { switch (other_dtype) { case kFloat: - gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { -#ifdef AT_USE_NV_CVT_INTRINSICS - const auto x = __nv_cvt_float_to_fp8(value, __NV_NOSAT, __NV_E5M2); - return Float8_e5m2(x, Float8_e5m2::from_bits()); -#else - return Float8_e5m2(value); -#endif - }); + gpu_kernel_nocast(iter, ConvertFloatToFloat8E5M2Op{}); break; case kHalf: gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { -#ifdef AT_USE_NV_CVT_INTRINSICS - const auto x = __nv_cvt_halfraw_to_fp8(static_cast<__half>(value), __NV_NOSAT, __NV_E5M2); - return Float8_e5m2(x, Float8_e5m2::from_bits()); -#else return Float8_e5m2(value); -#endif }); break; case kBFloat16: gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { -#ifdef AT_USE_NV_CVT_INTRINSICS - const auto x = __nv_cvt_bfloat16raw_to_fp8(static_cast<__nv_bfloat16>(value), __NV_NOSAT, __NV_E5M2); - return Float8_e5m2(x, Float8_e5m2::from_bits()); -#else return Float8_e5m2(value); -#endif }); break; default: @@ -238,7 +248,7 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) { AT_DISPATCH_V2( dtype, "copy_", AT_WRAP([&] { gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; }); - }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kHalf, kBool, kBFloat16, kComplexHalf, kBComplex32, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kHalf, kBool, kBFloat16, kComplexHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } } diff --git a/aten/src/ATen/native/cuda/FillKernel.cu b/aten/src/ATen/native/cuda/FillKernel.cu index 1772892ac7c7b..266f0e49b8e5a 100644 --- a/aten/src/ATen/native/cuda/FillKernel.cu +++ b/aten/src/ATen/native/cuda/FillKernel.cu @@ -22,7 +22,7 @@ struct FillFunctor { void fill_kernel_cuda(TensorIterator& iter, const Scalar& value) { AT_DISPATCH_V2(iter.dtype(), "fill_cuda", AT_WRAP([&]() { gpu_kernel(iter, FillFunctor(value.to())); - }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBComplex32, kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } REGISTER_DISPATCH(fill_stub, &fill_kernel_cuda) diff --git a/aten/src/ATen/native/cuda/GroupedBlas.cpp b/aten/src/ATen/native/cuda/GroupedBlas.cpp index e617c46d5cbce..8c9125879dc06 100644 --- a/aten/src/ATen/native/cuda/GroupedBlas.cpp +++ b/aten/src/ATen/native/cuda/GroupedBlas.cpp @@ -132,10 +132,10 @@ _mx8_mx8_bf16_grouped_mm_mslk( scale_b, offs.value(), out); + return out; #else TORCH_CHECK_NOT_IMPLEMENTED(false, "mxfp8_mxfp8 grouped gemm requires compile with USE_MSLK"); #endif - return out; } // 2d-2d and 2d-3d cases @@ -197,10 +197,10 @@ _f8_f8_bf16_rowwise_grouped_mm_rocm( scale_b, offs, out); + return out; #else TORCH_CHECK_NOT_IMPLEMENTED(false, "grouped gemm is not supported without USE_MSLK on ROCM") #endif - return out; } #endif // USE_ROCM @@ -291,11 +291,11 @@ _f4_f4_bf16_grouped_mm_mslk( out, combined_global_scale ); + + return out; #else TORCH_CHECK_NOT_IMPLEMENTED(false, "nvfp4 grouped gemm is not supported without USE_MSLK, and only for CUDA") #endif - - return out; } void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) { diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu index 3a331825cfcfa..04b0756817d51 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cu +++ b/aten/src/ATen/native/cuda/IndexKernel.cu @@ -225,7 +225,6 @@ static void index_kernel( AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES), kComplexHalf, - kBComplex32, kHalf, kBool, kBFloat16); @@ -237,9 +236,8 @@ static void index_fill_kernel( const int64_t self_dim_size, const int64_t self_dim_stride, const Scalar& source) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, kComplexHalf, - kBComplex32, iter.dtype(), "index_fill_cuda", [&] { using dtype = OpaqueType; const auto fill_val = source.to(); @@ -256,9 +254,8 @@ static void index_copy_kernel( // See note [Writing Nondeterministic Operations] // Nondeterministic when index contains duplicate entries // this kernel will not be called when torch.use_deterministic_algorithms(True) - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, kComplexHalf, - kBComplex32, iter.dtype(), "index_copy_cuda", [&] { using dtype = OpaqueType; index_copy_kernel_impl(iter, dim, self_dim_size, self_dim_stride); @@ -278,7 +275,6 @@ static void index_put_kernel(TensorIterator& iter, const IntArrayRef index_size, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES), kComplexHalf, - kBComplex32, kHalf, kBool, kBFloat16); @@ -502,7 +498,6 @@ void flip_kernel(TensorIterator& iter, const bool quantized) { AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kComplexHalf, - kBComplex32, kHalf, kBool, kBFloat16); diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 366c38b97e09e..a052943c597ec 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -752,7 +752,6 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List(); gpu_kernel( iter, [value_] GPU_LAMBDA(scalar_t self, bool mask) -> scalar_t { diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index 60e5aa54aae0e..063db29d38279 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -95,7 +95,7 @@ void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f, bool check_cas if (!iter.can_use_32bit_indexing()) { for (auto& sub_iter : iter.with_32bit_indexing()) { - gpu_kernel_nocast(sub_iter, f); + gpu_kernel_nocast(sub_iter, f, check_cast); } return; } diff --git a/aten/src/ATen/native/cuda/ScaledBlas.cpp b/aten/src/ATen/native/cuda/ScaledBlas.cpp index 6f33092b99eea..2c28dc14a0c8a 100644 --- a/aten/src/ATen/native/cuda/ScaledBlas.cpp +++ b/aten/src/ATen/native/cuda/ScaledBlas.cpp @@ -205,8 +205,7 @@ bool is_desired_scaling(const at::Tensor& t, const at::Tensor& scale, ScalingTyp case ScalingType::BlockWise128x128: return is_blockwise_128x128_scaling(t, scale); default: - TORCH_CHECK(false); - return false; + TORCH_CHECK(false, "Unknown scaling type"); } } diff --git a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu index 30d9f49168eb1..06a2805e89fca 100644 --- a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu +++ b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu @@ -311,7 +311,6 @@ struct cuda_scatter_gather_base_kernel { AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), AT_EXPAND(AT_FLOAT8_TYPES), kComplexHalf, - kBComplex32, kHalf, kBool, kBFloat16); diff --git a/aten/src/ATen/native/cuda/Shape.cu b/aten/src/ATen/native/cuda/Shape.cu index 13c42ed23cebb..374684e2016e9 100644 --- a/aten/src/ATen/native/cuda/Shape.cu +++ b/aten/src/ATen/native/cuda/Shape.cu @@ -609,7 +609,6 @@ TORCH_IMPL_FUNC(cat_out_cuda) }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, - kBComplex32, kHalf, kBool, kBFloat16, @@ -643,7 +642,6 @@ TORCH_IMPL_FUNC(cat_out_cuda) }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, - kBComplex32, kHalf, kBool, kBFloat16, diff --git a/aten/src/ATen/native/cuda/UnarySignKernels.cu b/aten/src/ATen/native/cuda/UnarySignKernels.cu index 80f2805f0128e..2736aa33bc2f1 100644 --- a/aten/src/ATen/native/cuda/UnarySignKernels.cu +++ b/aten/src/ATen/native/cuda/UnarySignKernels.cu @@ -36,7 +36,7 @@ void neg_kernel_cuda(TensorIteratorBase& iter) { return -a; } ); // neg_string - AT_DISPATCH_COMPLEX_TYPES_AND2(kComplexHalf, kBComplex32, dtype, "neg_cuda", [&]() { + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "neg_cuda", [&]() { jitted_gpu_kernel< /*name=*/ neg_name, /*return_dtype=*/ scalar_t, @@ -44,7 +44,7 @@ void neg_kernel_cuda(TensorIteratorBase& iter) { /*arity=*/ 1>(iter, neg_string); }); #else - AT_DISPATCH_COMPLEX_TYPES_AND2(kComplexHalf, kBComplex32, dtype, "neg_cuda", [&]() { + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "neg_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return -a; }); diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp index 3ad4242618d4a..74739b348db1f 100644 --- a/aten/src/ATen/native/cuda/jit_utils.cpp +++ b/aten/src/ATen/native/cuda/jit_utils.cpp @@ -981,7 +981,6 @@ int calc_thread_work_size( } else { return 4; } - return io_size; #else auto io_size = at::cuda::jit::calc_io_size(nInputs, nOutputs, inputs_type, result_type); TORCH_INTERNAL_ASSERT(io_size > 0); @@ -990,7 +989,6 @@ int calc_thread_work_size( } else { return 8; } - return io_size; #endif } diff --git a/aten/src/ATen/native/cuda/jit_utils.h b/aten/src/ATen/native/cuda/jit_utils.h index d5aeb21584abe..7961af63cced2 100644 --- a/aten/src/ATen/native/cuda/jit_utils.h +++ b/aten/src/ATen/native/cuda/jit_utils.h @@ -204,9 +204,6 @@ template <> inline std::string typeName(){ template <> inline std::string typeName>(){ return "std::complex"; } -template <> inline std::string typeName>(){ - return "std::complex"; -} template <> inline std::string typeName>(){ return "std::complex"; } diff --git a/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp b/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp index cd87850fe9eb9..6b34c0c7d8b4a 100644 --- a/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp +++ b/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp @@ -23,7 +23,6 @@ Tensor& _sparse_mm_mkl_( #else TORCH_CHECK(false, "sparse_mm_mkl: ATen not compiled with MKL support"); #endif - return self; // for stopping compiler warnings. } } // namespace native diff --git a/aten/src/ATen/native/mkldnn/xpu/Conv.cpp b/aten/src/ATen/native/mkldnn/xpu/Conv.cpp index eb8f154dedaf0..0553aac30e9eb 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Conv.cpp @@ -384,9 +384,12 @@ Tensor _convolution_out( at::MemoryFormat mfmt = is_channels_last_suggested ? get_cl_tag_by_ndim(input.ndimension()) : at::MemoryFormat::Contiguous; - auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r; - input = input.contiguous(mfmt); - weight = weight.contiguous(mfmt); + + auto bias = bias_r.defined() + ? make_contiguous_and_aligned(bias_r) + : bias_r; + input = make_contiguous_and_aligned(input, mfmt); + weight = make_contiguous_and_aligned(weight, mfmt); check_shape_forward(input, weight, bias, params, true); Tensor output; @@ -591,9 +594,9 @@ std::tuple convolution_backward_overrideable( auto mfmt = is_channels_last_suggested ? get_cl_tag_by_ndim(input_.ndimension()) : at::MemoryFormat::Contiguous; - grad_output_ = grad_output_.contiguous(mfmt); - weight_ = weight_.contiguous(mfmt); - input_ = input_.contiguous(mfmt); + grad_output_ = make_contiguous_and_aligned(grad_output_, mfmt); + weight_ = make_contiguous_and_aligned(weight_, mfmt); + input_ = make_contiguous_and_aligned(input_, mfmt); auto opt = grad_output_.options(); Tensor grad_input; diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp index a8a6b870ff6b6..b5bcb9b105184 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp @@ -287,6 +287,25 @@ void undo_broadcast(at::Tensor& tensor) { return; } +bool is_64_bytes_aligned(const at::Tensor& tensor) { + constexpr uintptr_t alignment_byte = 64; + return reinterpret_cast(tensor.data_ptr()) % alignment_byte == 0; +} + +at::Tensor make_contiguous_and_aligned( + const at::Tensor& tensor, + std::optional memory_format) { + at::Tensor out = memory_format.has_value() + ? tensor.contiguous(*memory_format) + : tensor.contiguous(); + + if (out.storage_offset() > 0 && !is_64_bytes_aligned(out)) { + out = out.clone(); + } + + return out; +} + bool is_onednn_matmul_strides(const at::Tensor& tensor) { // https://oneapi-src.github.io/oneDNN/dev_guide_matmul.html // oneDNN matmul only support 2-dim and 3-dim @@ -300,11 +319,8 @@ bool is_onednn_matmul_strides(const at::Tensor& tensor) { if (tensor.is_contiguous()) return true; - if (tensor.storage_offset() > 0) { - // currently onednn asks 64 byte alignment - constexpr int alignment_byte = 64; - if (reinterpret_cast(tensor.data_ptr()) % alignment_byte > 0) - return false; + if (tensor.storage_offset() > 0 && !is_64_bytes_aligned(tensor)) { + return false; } // the overlapped cases are not supported diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h index 0055fc2f296ad..a37d57a87499c 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -46,6 +47,12 @@ void undo_broadcast(at::Tensor& tensor); bool is_onednn_matmul_strides(const at::Tensor& tensor); +bool is_64_bytes_aligned(const at::Tensor& tensor); + +at::Tensor make_contiguous_and_aligned( + const at::Tensor& tensor, + std::optional memory_format = std::nullopt); + bool is_broadcast_from_other_to_self( const at::Tensor& self, const at::Tensor& other); diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index c994bbeadbfd1..eb7deb62a9042 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -122,6 +122,20 @@ struct logaddexp2_functor { } }; +struct xlogy_functor { + template , bool> = true> + inline T operator()(const T a, const T b) { + return static_cast(c10::metal::xlogy(a, b)); + } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::xlogy(float(a), float(b)); + } + inline float operator()(const bool a, const bool b) { + return (a && !b) ? -INFINITY : 0; + } +}; + struct xlog1py_functor { template , bool> = true> inline T operator()(const T a, const T b) { @@ -449,6 +463,8 @@ REGISTER_FLOAT_BINARY_OP(logaddexp); REGISTER_INT2FLOAT_BINARY_OP(logaddexp); REGISTER_FLOAT_BINARY_OP(logaddexp2); REGISTER_INT2FLOAT_BINARY_OP(logaddexp2); +REGISTER_FLOAT_BINARY_OP(xlogy); +REGISTER_INT2FLOAT_BINARY_OP(xlogy); REGISTER_FLOAT_BINARY_OP(xlog1py); REGISTER_INT2FLOAT_BINARY_OP(xlog1py); REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_t); diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 03ecee79449c0..66190e185f3dc 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -109,6 +109,10 @@ static void logaddexp2_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "logaddexp2"); } +static void xlogy_mps_kernel(TensorIteratorBase& iter) { + lib.exec_binary_kernel(iter, "xlogy"); +} + static void xlog1py_mps_kernel(TensorIteratorBase& iter) { TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "xlog1py_mps not implemented for non-floating types"); lib.exec_binary_kernel(iter, "xlog1py"); @@ -242,6 +246,7 @@ static void gcd_mps_kernel(TensorIteratorBase& iter) { REGISTER_DISPATCH(zeta_stub, &zeta_mps_kernel) REGISTER_DISPATCH(logaddexp_stub, &logaddexp_mps_kernel); REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_mps_kernel); +REGISTER_DISPATCH(xlogy_stub, &xlogy_mps_kernel) REGISTER_DISPATCH(xlog1py_stub, &xlog1py_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_t_stub, &chebyshev_polynomial_t_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kernel) diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index 35ebff2a8d4bd..e56dd92679507 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -26,7 +26,6 @@ #include #include #include -#include #endif namespace at::native { @@ -262,29 +261,4 @@ static void add_sub_lerp_template(const Tensor& self, } } -TORCH_IMPL_FUNC(xlogy_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { - mps::BinaryOpBlock xlogy_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { - MPSGraph* mpsGraph = cachedGraph->graph(); - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType]; - MPSGraphTensor* yIsNaNPredicateTensor = [mpsGraph isNaNWithTensor:secondaryCastTensor name:nil]; - MPSGraphTensor* logyTensor = [mpsGraph logarithmWithTensor:secondaryCastTensor name:nil]; - MPSGraphTensor* xlogyTensor = [mpsGraph multiplicationWithPrimaryTensor:primaryCastTensor - secondaryTensor:logyTensor - name:nil]; - MPSGraphTensor* xEqualZeroPredicateTensor = [mpsGraph equalWithPrimaryTensor:primaryCastTensor - secondaryTensor:zeroTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:xEqualZeroPredicateTensor - truePredicateTensor:zeroTensor - falsePredicateTensor:xlogyTensor - name:nil]; - outputTensor = [mpsGraph selectWithPredicateTensor:yIsNaNPredicateTensor - truePredicateTensor:secondaryCastTensor - falsePredicateTensor:outputTensor - name:nil]; - return outputTensor; - }; - mps::binaryOpTensor(self, other, output, "xlogy_out_mps", xlogy_op_block); -} - } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 4cdb992b469d5..9edab6b2d3552 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -480,6 +480,7 @@ CompositeExplicitAutograd: _conj_physical SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: conj_physical_sparse_csr autogen: _conj_physical.out + tags: pointwise - func: conj_physical(Tensor self) -> Tensor variants: function, method @@ -3350,11 +3351,13 @@ dispatch: CUDA: _fused_rms_norm_cuda MPS: _fused_rms_norm_mps + XPU: _fused_rms_norm_xpu CompositeImplicitAutograd: rms_norm_composite - func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor) dispatch: CUDA: _fused_rms_norm_backward_cuda + XPU: _fused_rms_norm_backward_xpu - func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor variants: function, method @@ -3691,8 +3694,7 @@ structured_inherits: TensorIteratorBase variants: function dispatch: - CPU, CUDA: xlogy_out - MPS: xlogy_out_mps + CPU, CUDA, MPS: xlogy_out tags: pointwise - func: xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) @@ -5381,6 +5383,7 @@ - func: selu_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator + tags: pointwise - func: celu(Tensor self, Scalar alpha=1.0) -> Tensor device_check: NoCheck # TensorIterator @@ -5393,6 +5396,7 @@ dispatch: CompositeExplicitAutograd: celu_ autogen: celu.out + tags: pointwise - func: silu(Tensor self) -> Tensor structured_delegate: silu.out @@ -12129,6 +12133,7 @@ structured_delegate: elu.out device_check: NoCheck # TensorIterator python_module: nn + tags: pointwise - func: glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -12190,6 +12195,7 @@ structured_delegate: hardsigmoid.out device_check: NoCheck # TensorIterator python_module: nn + tags: pointwise - func: hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) structured: True @@ -12235,6 +12241,7 @@ dispatch: CPU, CUDA, MPS: hardtanh_ QuantizedCPU: hardtanh_quantized_cpu_ + tags: pointwise - func: hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -12275,7 +12282,7 @@ python_module: nn dispatch: QuantizedCPU: leaky_relu_quantized_cpu - tags: core + tags: [core, pointwise] - func: leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!) structured: True @@ -12294,6 +12301,7 @@ python_module: nn dispatch: QuantizedCPU: leaky_relu_quantized_cpu_ + tags: pointwise - func: log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp index 318bbb3728a85..120af06ca840a 100644 --- a/aten/src/ATen/native/nested/NestedTensorMath.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp @@ -615,7 +615,6 @@ Tensor squeeze_nested(const Tensor& self) { "squeeze(): For nested tensors, squeeze without the dim argument is not supported ", "at the moment, however you can use squeeze(Tensor self, int dim) instead ", "if you need this feature, please open an issue on github describing your use case."); - return self; } Tensor squeeze_dim_nested(const Tensor& self, IntArrayRef dims) { @@ -1022,6 +1021,7 @@ static Tensor cat_nested_as_jagged( const auto first_item_dim = first_item.dim(); const auto first_item_batch_size = first_item.size(0); std::vector jagged_views; + jagged_views.reserve(tensors.size()); for (auto i : c10::irange(tensors.size())) { auto t = tensors[i].get(); TORCH_CHECK(t.is_nested(), @@ -1073,6 +1073,8 @@ static Tensor cat_nested_impl( // handle simple case of dim=0: concat NT components std::vector buffers; std::vector sizes; + buffers.reserve(tensors.size()); + sizes.reserve(tensors.size()); for (const auto i : c10::irange(tensors.size())) { const Tensor& t = tensors[i]; TORCH_CHECK( diff --git a/aten/src/ATen/native/nested/NestedTensorUtils.h b/aten/src/ATen/native/nested/NestedTensorUtils.h index 47b8e2ad8086f..43556def235ec 100644 --- a/aten/src/ATen/native/nested/NestedTensorUtils.h +++ b/aten/src/ATen/native/nested/NestedTensorUtils.h @@ -433,6 +433,8 @@ inline Tensor wrap_tensor_node( } else { // Slow path std::vector flat_tensors; std::vector sizes; + flat_tensors.reserve(tensor_node.degree()); + sizes.reserve(tensor_node.degree()); for (const auto i : c10::irange(tensor_node.degree())) { flat_tensors.push_back(tensor_node.children(i).reshape(-1).contiguous()); sizes.push_back( diff --git a/aten/src/ATen/native/quantized/cpu/ChannelShuffle.cpp b/aten/src/ATen/native/quantized/cpu/ChannelShuffle.cpp index 5df69d01b2549..36baffc36c1ed 100644 --- a/aten/src/ATen/native/quantized/cpu/ChannelShuffle.cpp +++ b/aten/src/ATen/native/quantized/cpu/ChannelShuffle.cpp @@ -102,10 +102,11 @@ Tensor channel_shuffle_quantized_cpu( int64_t groups) { #ifdef USE_PYTORCH_QNNPACK return quantized_channel_shuffle_impl(self, groups); -#endif +#else // If QNNPACK is not available then fall back to the // non quantized path. return at::native::channel_shuffle(self, groups); +#endif } // Keep the registry in the anonymous namespace. diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp index b115b25c42784..333796e8a24ce 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp @@ -94,7 +94,6 @@ at::Tensor PackedEmbeddingBagWeight::unpack() { TORCH_INTERNAL_ASSERT( false, "We currently only support 8-bit and 4-bit quantization of embedding_bag."); - return weight_origin; } namespace at::native { diff --git a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp index 7fe44de11e54c..857d0c52140e9 100644 --- a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp @@ -65,7 +65,6 @@ Tensor adaptive_avg_pool2d_quantized_cuda( return at::quantize_per_tensor(result_fp32, input.q_scale(), input.q_zero_point(), input.scalar_type()); #else // USE_CUDA TORCH_CHECK(false, "at::native::adaptive_avg_pool2d_quantized_cuda: ATen not compiled with USE_CUDA support"); - return Tensor{}; // never reached, placates the compiler #endif } @@ -214,7 +213,6 @@ Tensor quantized_max_pool2d_cudnn( #endif // AT_CUDNN_ENABLED() #else // USE_CUDA TORCH_CHECK(false, "at::native::quantized_max_pool2d_cudnn: ATen not compiled with USE_CUDA support"); - return Tensor{}; // never reached, placates the compiler #endif } diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp index edb95c23b98ba..cb8baf89a86e2 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp @@ -731,8 +731,8 @@ Tensor& addmm_out_sparse_compressed_cpu( " without MKL. PyTorch built with MKL has better support for addmm with sparse CPU tensors."); #else sparse::impl::mkl::addmm_out_sparse_csr(mat1, mat2, beta, alpha, result); -#endif return result; +#endif } Tensor addmm_sparse_compressed_dense( diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 5dec3746eaa88..b18b4dfccabac 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -95,7 +95,6 @@ bool is_coalesced_sparse(const SparseTensor& self) { bool is_coalesced_default(const Tensor& self) { TORCH_CHECK(false, "is_coalesced expected sparse coordinate tensor layout but got ", self.layout()); - return false; } int64_t _nnz_sparse(const SparseTensor& self) { diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index e04b64cf4efc1..7c684e84cf178 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -854,7 +854,6 @@ Tensor scaled_dot_product_attention( TORCH_CHECK( false, "No viable backend for scaled_dot_product_attention was found."); - return Tensor(); } } @@ -1064,11 +1063,6 @@ _scaled_dot_product_fused_attention_overrideable_backward( const at::Tensor & philox_offset, std::optional scale) { TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_dot_product_fused_attention_overrideable_backward not implemented: This is an operator for privateuse1 backends, please use TORCH_LIBRARY_IMPL to override this function "); - return std::tuple( - at::empty_like(query), - at::empty_like(key), - at::empty_like(value), - at::empty_like(attn_bias)); } Tensor triton_multi_head_attention( diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index f90d5beeb60be..79b5df3f302bb 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -937,8 +937,8 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { if (dprop->major >= 8) { return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug); } -#endif return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug); +#endif } SDPBackend select_sdp_backend(sdp_params const& kernel_params) { diff --git a/aten/src/ATen/native/vulkan/ops/Mm.cpp b/aten/src/ATen/native/vulkan/ops/Mm.cpp index aeb25b56e60e9..a5a0d93cccc3d 100644 --- a/aten/src/ATen/native/vulkan/ops/Mm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Mm.cpp @@ -631,6 +631,7 @@ Tensor run_quantized_addmm_context( return output; } else { std::vector shape; + shape.reserve(static_cast(std::max(0, input_arg.dim()))); for (const auto i : c10::irange(input_arg.dim() - 1)) { shape.emplace_back(input_arg.size(i)); } @@ -751,6 +752,7 @@ Tensor run_addmm_context( return output; } else { std::vector shape; + shape.reserve(static_cast(std::max(0, input_arg.dim()))); for (const auto i : c10::irange(input_arg.dim() - 1)) { shape.emplace_back(input_arg.size(i)); } diff --git a/aten/src/ATen/native/vulkan/ops/NativeLayerNorm.cpp b/aten/src/ATen/native/vulkan/ops/NativeLayerNorm.cpp index 1ec6957162cbb..338363c49cdbe 100644 --- a/aten/src/ATen/native/vulkan/ops/NativeLayerNorm.cpp +++ b/aten/src/ATen/native/vulkan/ops/NativeLayerNorm.cpp @@ -76,6 +76,7 @@ std::tuple native_layer_norm( const Tensor bias = bias_opt->is_vulkan() ? *bias_opt : bias_opt->vulkan(); std::vector dims_to_reduce; + dims_to_reduce.reserve(normalized_shape.size()); for (const auto i : c10::irange(normalized_shape.size())) { dims_to_reduce.push_back(input_arg.dim() - i - 1); } diff --git a/aten/src/ATen/native/vulkan/ops/Repeat.cpp b/aten/src/ATen/native/vulkan/ops/Repeat.cpp index 94c29b55d0571..0c08d7f7c26dc 100644 --- a/aten/src/ATen/native/vulkan/ops/Repeat.cpp +++ b/aten/src/ATen/native/vulkan/ops/Repeat.cpp @@ -37,6 +37,8 @@ Tensor repeat(const Tensor& self, const IntArrayRef repeats) { std::vector tensor_seq_to_concat; for (const auto i : c10::irange(out_ndims)) { + tensor_seq_to_concat.reserve( + static_cast(std::max(0, repeats[i]))); for (const auto k : c10::irange(repeats[i])) { (void)k; tensor_seq_to_concat.emplace_back(tensor_to_repeat.clone()); diff --git a/c10/core/ScalarType.cpp b/c10/core/ScalarType.cpp index 90741b45bcfcb..24cf425e41d18 100644 --- a/c10/core/ScalarType.cpp +++ b/c10/core/ScalarType.cpp @@ -19,11 +19,10 @@ constexpr auto c4 = ScalarType::ComplexFloat; constexpr auto c8 = ScalarType::ComplexDouble; constexpr auto b1 = ScalarType::Bool; constexpr auto bf = ScalarType::BFloat16; -constexpr auto cb = ScalarType::BComplex32; constexpr auto ud = ScalarType::Undefined; constexpr auto index2dtype = array_of< - c10::ScalarType>(u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf, cb); + c10::ScalarType>(u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf); constexpr std::array(ScalarType::NumOptions)> calculate_dtype2index() { @@ -110,21 +109,20 @@ ScalarType promoteTypes(ScalarType a, ScalarType b) { static constexpr std:: array, index2dtype.size()> _promoteTypesLookup = {{ - /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 bf cb */ - /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf, cb}, - /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, bf, cb}, - /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, bf, cb}, - /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, bf, cb}, - /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, bf, cb}, - /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, f4, c4}, - /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, f4, c4}, - /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, f8, c8}, - /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, c4, c4}, - /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, c4, c4}, - /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, - /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf, cb}, - /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, bf, cb}, - /* cb */ {cb, cb, cb, cb, cb, c4, c4, c8, c4, c4, c8, cb, cb, cb}, + /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 bf*/ + /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf}, + /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, bf}, + /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, bf}, + /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, bf}, + /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, bf}, + /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, f4}, + /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, f4}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, f8}, + /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, c4}, + /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, c4}, + /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, + /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf}, + /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, bf}, }}; // clang-format on return _promoteTypesLookup[ix_a][ix_b]; @@ -188,8 +186,6 @@ std::pair getDtypeNames(c10::ScalarType scalarType) { return std::make_pair("float16", "half"); case c10::ScalarType::ComplexHalf: return std::make_pair("complex32", "chalf"); - case c10::ScalarType::BComplex32: - return std::make_pair("bcomplex32", ""); case c10::ScalarType::ComplexFloat: return std::make_pair("complex64", "cfloat"); case c10::ScalarType::ComplexDouble: diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 840d0705c7b77..dc9d168f053e7 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -114,7 +114,7 @@ inline bool isFloatingType(ScalarType t) { inline bool isComplexType(ScalarType t) { return ( t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat || - t == ScalarType::ComplexDouble || t == ScalarType::BComplex32); + t == ScalarType::ComplexDouble); } inline bool isBitsType(ScalarType t) { @@ -187,7 +187,6 @@ inline bool isSignedType(ScalarType t) { CASE_ISSIGNED(ComplexHalf); CASE_ISSIGNED(ComplexFloat); CASE_ISSIGNED(ComplexDouble); - CASE_ISSIGNED(BComplex32); CASE_ISSIGNED(Bool); case ScalarType::Int1: case ScalarType::Int2: @@ -230,8 +229,6 @@ inline ScalarType toRealValueType(ScalarType t) { switch (t) { case ScalarType::ComplexHalf: return ScalarType::Half; - case ScalarType::BComplex32: - return ScalarType::BFloat16; case ScalarType::ComplexFloat: return ScalarType::Float; case ScalarType::ComplexDouble: @@ -244,7 +241,9 @@ inline ScalarType toRealValueType(ScalarType t) { inline ScalarType toComplexType(ScalarType t) { switch (t) { case ScalarType::BFloat16: - return ScalarType::BComplex32; + // BFloat16 has range equivalent to Float, + // so we map it to ComplexFloat. + return ScalarType::ComplexFloat; case ScalarType::Half: return ScalarType::ComplexHalf; case ScalarType::Float: @@ -253,8 +252,6 @@ inline ScalarType toComplexType(ScalarType t) { return ScalarType::ComplexDouble; case ScalarType::ComplexHalf: return ScalarType::ComplexHalf; - case ScalarType::BComplex32: - return ScalarType::BComplex32; case ScalarType::ComplexFloat: return ScalarType::ComplexFloat; case ScalarType::ComplexDouble: diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index cd9c9b86285c4..4e6097a406bc2 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -34,7 +34,7 @@ class C10_CUDA_API CUDAAllocatorConfig { static bool expandable_segments() { bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig:: use_expandable_segments(); -#if !defined(PYTORCH_C10_DRIVER_API_SUPPORTED) && !defined(USE_ROCM) +#ifndef PYTORCH_C10_DRIVER_API_SUPPORTED if (enabled) { TORCH_WARN_ONCE("expandable_segments not supported on this platform") } diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 6e94a5f694bf8..eb628b51968a4 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -17,17 +17,11 @@ #include #include -#if defined(PYTORCH_C10_DRIVER_API_SUPPORTED) || defined(USE_ROCM) -#if defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) #include -#endif -#ifndef _WIN32 #include #include #include -#else -#include -#endif #endif #include @@ -275,8 +269,7 @@ struct SegmentRange { SegmentRange(void* p, size_t s) : ptr(static_cast(p)), size(s) {} }; -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) || \ - defined(USE_ROCM) +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) /* Note [Expandable Segments] @@ -390,13 +383,8 @@ struct ExpandableSegment { // This allows for some cases where we have to unmap pages earlier in the // segment to put them at the end. max_handles_ = numSegments(prop.totalGlobalMem + prop.totalGlobalMem / 8); -#ifdef USE_ROCM - C10_CUDA_CHECK(hipMemAddressReserve( - &ptr_, segment_size_ * max_handles_, 0ULL, 0, 0ULL)); -#else C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemAddressReserve_( &ptr_, segment_size_ * max_handles_, 0ULL, 0, 0ULL)); -#endif } ExpandableSegment(const ExpandableSegment&) = delete; ExpandableSegment(ExpandableSegment&&) = delete; @@ -420,14 +408,12 @@ struct ExpandableSegment { // if it fails, use posix file handle if (CUDAAllocatorConfig::expandable_segments_handle_type() == Expandable_Segments_Handle_Type::UNSPECIFIED) { -#ifndef USE_ROCM CUDAAllocatorConfig::set_expandable_segments_handle_type( Expandable_Segments_Handle_Type::FABRIC_HANDLE); auto output = map(range); if (output.ptr != nullptr) { return output; } -#endif // if fabric handle is not supported, use posix file handle. CUDAAllocatorConfig::set_expandable_segments_handle_type( Expandable_Segments_Handle_Type::POSIX_FD); @@ -465,48 +451,27 @@ struct ExpandableSegment { } } int flag = 0; -#ifndef USE_ROCM C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuDeviceGetAttribute_( &flag, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED, device_)); -#endif if (flag) prop.allocFlags.gpuDirectRDMACapable = 1; prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; // NOLINTNEXTLINE(bugprone-signed-char-misuse) prop.location.id = static_cast(device_); -#ifdef USE_ROCM - auto status = hipMemCreate(&handle, segment_size_, &prop, 0); -#else auto status = DriverAPI::get()->cuMemCreate_(&handle, segment_size_, &prop, 0); -#endif if (status != CUDA_SUCCESS) { if (status == CUDA_ERROR_OUT_OF_MEMORY) { -#ifdef USE_ROCM - // hipMemCreate above returned hipErrorOutOfMemory and treated it - // like a sticky runtime error. Which means we need to clear it. - // Unlike the corresponding CUDA Driver API. - (void)hipGetLastError(); -#endif for (auto j : c10::irange(begin, i)) { // NOLINTNEXTLINE(bugprone-unchecked-optional-access) auto h = handles_.at(j).value(); handles_.at(j) = std::nullopt; -#ifdef USE_ROCM - C10_CUDA_CHECK(hipMemRelease(h.handle)); -#else C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h.handle)); -#endif } trimHandles(); return rangeFromHandles(begin, begin); -#ifdef USE_ROCM - } else { - C10_CUDA_CHECK(status); - } -#else } else if ( CUDAAllocatorConfig::expandable_segments_handle_type() == Expandable_Segments_Handle_Type::FABRIC_HANDLE) { @@ -522,7 +487,6 @@ struct ExpandableSegment { } else { C10_CUDA_DRIVER_CHECK(status); } -#endif } handles_.at(i) = Handle{handle, std::nullopt}; } @@ -558,11 +522,7 @@ struct ExpandableSegment { // thereby ensuring that the handle can be correctly matched in // ipcMemHandle_to_devptr. ShareHeader header{}; -#ifdef _WIN32 - header.pid = _getpid(); -#else header.pid = getpid(); -#endif header.segment_size = segment_size_; header.num_handles = end - begin; @@ -574,13 +534,8 @@ struct ExpandableSegment { Expandable_Segments_Handle_Type::FABRIC_HANDLE) { if (!handle.shareable_handle) { int fd = 0; -#ifdef USE_ROCM - C10_CUDA_CHECK(hipMemExportToShareableHandle( - &fd, handle.handle, hipMemHandleTypePosixFileDescriptor, 0)); -#else C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemExportToShareableHandle_( &fd, handle.handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0)); -#endif handle.shareable_handle = fd; LOG(INFO) << "use posix fd to share expandable segments."; } @@ -591,10 +546,6 @@ struct ExpandableSegment { reinterpret_cast(&*handle.shareable_handle), sizeof(int)); } else { -#ifdef USE_ROCM - TORCH_INTERNAL_ASSERT( - false, "expandable segment with fabric handle not supported"); -#else if (!handle.shareable_handle) { CUmemFabricHandle fabric_handle; C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemExportToShareableHandle_( @@ -608,7 +559,6 @@ struct ExpandableSegment { buf.write( reinterpret_cast(&*handle.shareable_handle), sizeof(CUmemFabricHandle)); -#endif } } return rangeFromHandles(begin, end); @@ -624,20 +574,14 @@ struct ExpandableSegment { device, std::nullopt, header.segment_size, std::move(peers)); // older build setups (e.g. multiwheels) do not have this syscall, added 2020 // but the kernel on the system might still support it. -#ifndef _WIN32 #ifndef SYS_pidfd_open #define SYS_pidfd_open 434 #endif #ifndef SYS_pidfd_getfd #define SYS_pidfd_getfd 438 #endif -#endif // !_WIN32 if (CUDAAllocatorConfig::expandable_segments_handle_type() != Expandable_Segments_Handle_Type::FABRIC_HANDLE) { -#ifdef _WIN32 - TORCH_CHECK( - false, "IPC expandable segments are not supported on Windows"); -#else auto pidfd = syscall(SYS_pidfd_open, header.pid, 0); TORCH_CHECK( pidfd != -1 || errno != ENOSYS, @@ -653,13 +597,9 @@ struct ExpandableSegment { auto err = errno; close(static_cast(pidfd)); for (auto& h : segment->handles_) { -#ifdef USE_ROCM - C10_CUDA_CHECK(hipMemRelease(h.value().handle)); -#else C10_CUDA_DRIVER_CHECK( // NOLINTNEXTLINE(bugprone-unchecked-optional-access) DriverAPI::get()->cuMemRelease_(h.value().handle)); -#endif h = std::nullopt; } TORCH_CHECK( @@ -669,33 +609,17 @@ struct ExpandableSegment { TORCH_CHECK(false, "pidfd_getfd: ", c10::utils::str_error(err)); } CUmemGenericAllocationHandle handle = 0; -#ifdef USE_ROCM -#if ROCM_VERSION >= 70100 - void* myfd_handle = - reinterpret_cast(static_cast(myfd)); -#else - void* myfd_handle = (void*)(uintptr_t)&myfd; -#endif - C10_CUDA_CHECK(hipMemImportFromShareableHandle( - &handle, myfd_handle, hipMemHandleTypePosixFileDescriptor)); -#else C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemImportFromShareableHandle_( &handle, // NOLINTNEXTLINE(performance-no-int-to-ptr) (void*)(uintptr_t)myfd, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); -#endif LOG(INFO) << "use posix fd to import expandable segments."; close(static_cast(myfd)); segment->handles_.emplace_back(Handle{handle, std::nullopt}); } close(static_cast(pidfd)); -#endif // !_WIN32 } else { -#ifdef USE_ROCM - TORCH_INTERNAL_ASSERT( - false, "expandable segment with fabric handle not supported"); -#else for (auto i : c10::irange(header.num_handles)) { (void)i; CUmemFabricHandle fabric_handle; @@ -710,7 +634,6 @@ struct ExpandableSegment { LOG(INFO) << "use fabric handle to import expandable segments."; segment->handles_.emplace_back(Handle{handle, std::nullopt}); } -#endif } segment->mapAndSetAccess(0, header.num_handles); return segment; @@ -746,12 +669,8 @@ struct ExpandableSegment { ~ExpandableSegment() { forEachAllocatedRange( [&](size_t begin, size_t end) { unmapHandles(begin, end); }); -#ifdef USE_ROCM - C10_CUDA_CHECK(hipMemAddressFree(ptr_, segment_size_ * max_handles_)); -#else C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemAddressFree_( ptr_, segment_size_ * max_handles_)); -#endif } private: @@ -761,28 +680,12 @@ struct ExpandableSegment { // NOLINTNEXTLINE(bugprone-signed-char-misuse) desc.location.id = static_cast(device); desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; -#ifdef USE_ROCM - C10_CUDA_CHECK(hipMemSetAccess( - ptr() + begin * segment_size_, - (end - begin) * segment_size_, - &desc, - 1)); -#else C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemSetAccess_( ptr_ + begin * segment_size_, (end - begin) * segment_size_, &desc, 1)); -#endif } void mapAndSetAccess(size_t begin, size_t end) { for (auto i : c10::irange(begin, end)) { -#ifdef USE_ROCM - C10_CUDA_CHECK(hipMemMap( - ptr() + i * segment_size_, - segment_size_, - 0, - handles_.at(i).value().handle, - 0ULL)); -#else C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemMap_( ptr_ + i * segment_size_, segment_size_, @@ -790,7 +693,6 @@ struct ExpandableSegment { // NOLINTNEXTLINE(bugprone-unchecked-optional-access) handles_.at(i).value().handle, 0ULL)); -#endif } mapped_size_ += (end - begin) * segment_size_; setAccess(device_, begin, end); @@ -817,22 +719,12 @@ struct ExpandableSegment { // NOLINTNEXTLINE(bugprone-unchecked-optional-access) Handle h = handles_.at(i).value(); handles_.at(i) = std::nullopt; -#ifdef USE_ROCM - C10_CUDA_CHECK(hipMemUnmap(ptr() + segment_size_ * i, segment_size_)); -#else C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemUnmap_( ptr_ + segment_size_ * i, segment_size_)); -#endif if (h.shareable_handle) { -#ifndef _WIN32 close(std::get(*h.shareable_handle)); -#endif } -#ifdef USE_ROCM - C10_CUDA_CHECK(hipMemRelease(h.handle)); -#else C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h.handle)); -#endif } trimHandles(); } @@ -880,11 +772,7 @@ struct ExpandableSegment { std::optional> shareable_handle; }; struct ShareHeader { -#ifdef _WIN32 - int pid; -#else pid_t pid; -#endif size_t segment_size; size_t num_handles; }; diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h index 8d6131d867068..ead38cbfff1ec 100644 --- a/c10/metal/special_math.h +++ b/c10/metal/special_math.h @@ -723,6 +723,19 @@ inline T logaddexp2(T a, T b) { } } +template +inline float xlogy(T x, T y) { + if (::metal::isnan(y)) { + return NAN; + } + + if (x == 0) { + return x; + } + + return x * precise::log(float(y)); +} + template inline float xlog1py(T x, T y) { if (::metal::isnan(y)) { diff --git a/c10/util/Semaphore.h b/c10/util/Semaphore.h index 041d9abecf515..6d2235c09da5a 100644 --- a/c10/util/Semaphore.h +++ b/c10/util/Semaphore.h @@ -8,7 +8,13 @@ // note: __cpp_lib_semaphore will not be defined in some apple platforms // even if >= C++20. -#if __has_include() && defined(__cpp_lib_semaphore) && __cpp_lib_semaphore >= 201907L +// +// libstdc++'s __atomic_semaphore has a lost-wakeup bug: _M_release skips +// the futex notify when the counter is already positive, but a concurrent +// _S_do_try_acquire can fail its CAS, see zero, and block — missing the +// wakeup. https://gcc.gnu.org/bugzilla/show_bug.cgi?id=98033 +#if __has_include() && defined(__cpp_lib_semaphore) && \ + __cpp_lib_semaphore >= 201907L && !defined(__GLIBCXX__) #define C10_SEMAPHORE_USE_STL #endif diff --git a/c10/util/TypeCast.h b/c10/util/TypeCast.h index bb42b6aac7cf8..d8a92c2eaa8c2 100644 --- a/c10/util/TypeCast.h +++ b/c10/util/TypeCast.h @@ -186,142 +186,6 @@ struct static_cast_with_inter_type< } }; -template <> -struct static_cast_with_inter_type< - c10::complex, - c10::Float8_e5m2> { - C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< - c10::BFloat16> - apply(c10::Float8_e5m2 src) { - return static_cast>(c10::complex{src}); - } -}; - -template <> -struct static_cast_with_inter_type< - c10::complex, - c10::Float8_e5m2fnuz> { - C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< - c10::BFloat16> - apply(c10::Float8_e5m2fnuz src) { - return static_cast>(c10::complex{src}); - } -}; - -template <> -struct static_cast_with_inter_type< - c10::complex, - c10::Float8_e4m3fn> { - C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< - c10::BFloat16> - apply(c10::Float8_e4m3fn src) { - return static_cast>(c10::complex{src}); - } -}; - -template <> -struct static_cast_with_inter_type< - c10::complex, - c10::Float8_e4m3fnuz> { - C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< - c10::BFloat16> - apply(c10::Float8_e4m3fnuz src) { - return static_cast>(c10::complex{src}); - } -}; - -// TODO(#146647): Can we make all these template specialization happen -// based off our apply macros? -template <> -struct static_cast_with_inter_type< - c10::complex, - c10::Float8_e8m0fnu> { - C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< - c10::BFloat16> - apply(c10::Float8_e8m0fnu src) { - return static_cast>(c10::complex{src}); - } -}; - -template <> -struct static_cast_with_inter_type, c10::Half> { - C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< - c10::BFloat16> - apply(c10::Half src) { - return static_cast>( - static_cast>(src)); - } -}; - -template <> -struct static_cast_with_inter_type< - c10::complex, - c10::complex> { - C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< - c10::BFloat16> - apply(c10::complex src) { - return static_cast>( - static_cast>(src)); - } -}; - -template <> -struct static_cast_with_inter_type< - c10::complex, - c10::complex> { - C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< - c10::BFloat16> - apply(c10::complex src) { - return static_cast>( - static_cast>(src)); - } -}; - -template <> -struct static_cast_with_inter_type< - c10::complex, - c10::complex> { - C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< - c10::Half> - apply(c10::complex src) { - return static_cast>( - static_cast>(src)); - } -}; - -template <> -struct static_cast_with_inter_type> { - C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::Half apply( - c10::complex src) { - return static_cast(static_cast(src.real())); - } -}; - -template <> -struct static_cast_with_inter_type> { - C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::BFloat16 apply( - c10::complex src) { - return static_cast(static_cast(src.real())); - } -}; - -template <> -struct static_cast_with_inter_type> { - C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::BFloat16 apply( - c10::complex src) { - return src.real(); - } -}; - -template <> -struct static_cast_with_inter_type, c10::BFloat16> { - C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< - c10::BFloat16> - apply(c10::BFloat16 src) { - return c10::complex{src, 0}; - } -}; - template C10_HOST_DEVICE To convert(From f) { return static_cast_with_inter_type::apply(f); diff --git a/docs/source/tensor_attributes.rst b/docs/source/tensor_attributes.rst index 64f4599e603e5..1f6c03281d643 100644 --- a/docs/source/tensor_attributes.rst +++ b/docs/source/tensor_attributes.rst @@ -27,7 +27,6 @@ dtype description ``torch.float16`` or ``torch.half`` 16-bit floating point, as defined in https://en.wikipedia.org/wiki/IEEE_754, S-E-M 1-5-10 ``torch.bfloat16`` 16-bit floating point, sometimes referred to as Brain floating point, S-E-M 1-8-7 ``torch.complex32`` or ``torch.chalf`` 32-bit complex with two `float16` components -``torch.bcomplex32`` [shell]_ 32-bit complex with two `bfloat16` components ``torch.complex64`` or ``torch.cfloat`` 64-bit complex with two `float32` components ``torch.complex128`` or ``torch.cdouble`` 128-bit complex with two `float64` components ``torch.float8_e4m3fn`` [shell]_, [1]_ 8-bit floating point, S-E-M 1-4-3, from https://arxiv.org/abs/2209.05433 diff --git a/scripts/build_host_protoc.sh b/scripts/build_host_protoc.sh old mode 100644 new mode 100755 diff --git a/test/cpp/aoti_abi_check/test_dtype.cpp b/test/cpp/aoti_abi_check/test_dtype.cpp index 053bcc22cccf6..e6e7e75867c8d 100644 --- a/test/cpp/aoti_abi_check/test_dtype.cpp +++ b/test/cpp/aoti_abi_check/test_dtype.cpp @@ -200,7 +200,6 @@ TEST(TestDtype, TestScalarType) { ScalarType::Int7, ScalarType::Float8_e8m0fnu, ScalarType::Float4_e2m1fn_x2, - ScalarType::BComplex32, ScalarType::Undefined, }; for (int8_t i = 0; i < static_cast(torch::headeronly::NumScalarTypes); diff --git a/test/cpp/aoti_abi_check/test_scalartype.cpp b/test/cpp/aoti_abi_check/test_scalartype.cpp index d81b62c9909ea..6df242b5a4cec 100644 --- a/test/cpp/aoti_abi_check/test_scalartype.cpp +++ b/test/cpp/aoti_abi_check/test_scalartype.cpp @@ -41,8 +41,8 @@ TEST(TestScalarType, CppTypeToScalarType) { } TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ, 14) -TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, 19) -TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS, 47) +TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, 18) +TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS, 46) TEST_FORALL(AT_FORALL_INT_TYPES, 5) TEST_FORALL(AT_FORALL_SCALAR_TYPES, 7) TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND, 8, Bool, ) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index da66be7ce0f6f..0d9daab6be733 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -13,6 +13,8 @@ import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +from torch._C._autograd import DeviceType +from torch._C._distributed_c10d import _SymmetricMemory from torch.distributed._composable import checkpoint, replicate from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing, @@ -44,8 +46,10 @@ from torch.distributed.tensor import DTensor from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.experimental import implicit_replication +from torch.testing._internal.common_cuda import SM90OrLater, TEST_MULTIGPU from torch.testing._internal.common_distributed import ( - requires_multicast_support, + MultiProcContinuousTest, + PLATFORM_SUPPORTS_SYMM_MEM, skip_if_lt_x_gpu, ) from torch.testing._internal.common_fsdp import ( @@ -59,7 +63,11 @@ patch_unshard, ) from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + requires_cuda_p2p_access, run_tests, + skip_but_pass_in_sandcastle_if, TEST_WITH_ROCM, TEST_XPU, xfailIf, @@ -70,6 +78,7 @@ Transformer, TransformerBlock, ) +from torch.testing._internal.inductor_utils import skipCUDAIf c10d_ops = torch.ops.c10d @@ -1638,8 +1647,15 @@ def _run(cls, *args, **kwargs): @skip_if_lt_x_gpu(2) # The NCCL PG refuses to allocate tensors if multicast is unavailable, see # https://github.com/pytorch/pytorch/blob/503362d019b3782581492af7767945dbd75ca1c9/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L5634 - @requires_multicast_support() def test_fully_shard_alloc_from_pg(self): + # Run this check inside test instead of using @requires_multicast_support(). + # The decorator would trigger an initialization of SymmMem allocator + # when Python statically initializes classes in this file, causing + # SymmMem to fix the allocate backend to "CUDA". This is unfriendly for + # other tests in this file that requires NCCL backend + if not _SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0): + self.skipTest("multicast support is not available") + torch.manual_seed(42) model_args = ModelArgs() model = Transformer(model_args) @@ -1691,6 +1707,64 @@ def test_exception_when_used_together_with_comm_hooks(self): model.set_allocate_memory_from_process_group_for_comm(True) +@requires_cuda_p2p_access() +@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Not enough GPUs to run the test") +@unittest.skipIf( + not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this platform" +) +@skipCUDAIf(TEST_WITH_ROCM, "requires NVIDIA GPUs") +@skipCUDAIf(not SM90OrLater, "requires sm90+") +class TestFullyShardSymmMem(MultiProcContinuousTest): + @classmethod + def backend_str(cls) -> str | None: + return "nccl" + + @classmethod + def opts(cls): + if not dist.is_nccl_available(): + return None + # Enable Zero-CTA policy for CE collectives + opts = dist.ProcessGroupNCCL.Options() + opts.config.cta_policy = dist.ProcessGroupNCCL.NCCL_CTA_POLICY_ZERO + return opts + + @property + def device(self) -> torch.device: + return torch.device("cuda", self.rank) + + @parametrize("sum_reduction", [True, False]) + def test_fully_shard_symm_mem(self, sum_reduction: bool): + torch.manual_seed(42 + self.rank) + device = torch.device("cuda", self.rank) + torch.cuda.set_device(device) + seq_len = 64 + model_args = ModelArgs() + model_args.dim = 4096 + model_args.max_seq_len = seq_len + model = Transformer(model_args).to(device) + for module in model.modules(): + if isinstance(module, TransformerBlock): + fully_shard(module) + module.set_force_sum_reduction_for_comms(sum_reduction) + module.set_symm_mem_for_comm() + fully_shard(model) + model.set_force_sum_reduction_for_comms(sum_reduction) + model.set_symm_mem_for_comm() + + bs = 4 + inp = torch.randint(0, model_args.vocab_size, (bs, seq_len), device=device) + + def run(): + loss = model(inp) + loss.sum().backward() + + run() + torch.cuda.synchronize(device) + + +instantiate_parametrized_tests(TestFullyShardSymmMem) + + class TestFullyShardForceSumReduction(FSDPTest): # The messages might change when we move to a different NCCL version. # Please update this test if it starts failing. diff --git a/test/distributed/_composable/fsdp/test_fully_shard_dtensor.py b/test/distributed/_composable/fsdp/test_fully_shard_dtensor.py new file mode 100644 index 0000000000000..894e27a1d4b60 --- /dev/null +++ b/test/distributed/_composable/fsdp/test_fully_shard_dtensor.py @@ -0,0 +1,445 @@ +# Owner(s): ["oncall: distributed"] + +import copy + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._composable import replicate +from torch.distributed._composable.replicate_with_fsdp import ( + replicate as replicate_with_fsdp, +) +from torch.distributed.fsdp import DataParallelMeshDims, fully_shard +from torch.distributed.tensor import ( + distribute_module, + distribute_tensor, + DTensor, + init_device_mesh, + Replicate, + Shard, +) +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLP +from torch.testing._internal.common_utils import run_tests + + +device_type = torch.device(get_devtype()) + + +def _tp_partition_fn(name, module, device_mesh): + """Partition Linear weights across the last mesh dim for TP.""" + if not isinstance(module, nn.Linear): + return + num_non_tp_dims = device_mesh.ndim - 1 + replicate_prefix = [Replicate()] * num_non_tp_dims + for param_name, param in list(module.named_parameters(recurse=False)): + if param_name == "weight": + if "in_proj" in name: + placements = replicate_prefix + [Shard(0)] + else: + placements = replicate_prefix + [Shard(1)] + else: + placements = replicate_prefix + [Replicate()] + dist_param = nn.Parameter( + distribute_tensor(param, device_mesh, placements), + requires_grad=param.requires_grad, + ) + module.register_parameter(param_name, dist_param) + + +def _tp_shard_fn(param): + """FSDP shard placement that avoids the existing TP shard dim.""" + if any(isinstance(p, Shard) and p.dim == 0 for p in param.placements): + return Shard(1) + return Shard(0) + + +class TestFullyShardDTensor(FSDPTest): + @property + def world_size(self): + return min(4, torch.cuda.device_count()) + + def _run_train_parity( + self, model, ref_model, dp_pg, mesh=None, num_iters=5, mlp_dim=16 + ): + optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False) + ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False) + torch.manual_seed(42 + dp_pg.rank() + 1) + for i in range(num_iters): + inp = torch.randn((2, mlp_dim), device=device_type) + ref_optim.zero_grad(set_to_none=(i % 2 == 0)) + ref_loss = ref_model(inp).sum() + ref_loss.backward() + ref_optim.step() + + optim.zero_grad(set_to_none=(i % 2 == 0)) + if mesh is not None: + # Use Replicate on all dims: each rank computes on the full + # input (like DDP). DTensor handles the TP-sharded params. + inp = DTensor.from_local( + inp, mesh, [Replicate()] * mesh.ndim, run_check=False + ) + loss = model(inp).sum() + loss.backward() + optim.step() + + self.assertEqual(ref_loss, loss) + + for (n1, p1), (n2, p2) in zip( + ref_model.named_parameters(), model.named_parameters(), strict=True + ): + p2_full = p2.full_tensor() if isinstance(p2, DTensor) else p2 + self.assertEqual(p1, p2_full, msg=f"Param mismatch: {n1} vs {n2}") + + @skip_if_lt_x_gpu(2) + def test_dtensor_train_parity(self): + """Train parity for FSDP/HSDP/DDP with DTensors on SPMD meshes.""" + ws = self.world_size + world_mesh = init_device_mesh( + device_type.type, (ws,), mesh_dim_names=("world",) + ) + # (sizes, names, dp_dims, use_tp, reshard, dp_pg_source, use_rep_fsdp) + cases = [ + # 1D: FSDP + ( + (ws,), + ("fsdp",), + DataParallelMeshDims(shard="fsdp"), + False, + True, + None, + False, + ), + # 1D: FSDP with reshard_after_forward=False + ( + (ws,), + ("fsdp0",), + DataParallelMeshDims(shard="fsdp0"), + False, + False, + None, + False, + ), + # 1D: DDP-only + ( + (ws,), + ("ddp",), + DataParallelMeshDims(replicate="ddp"), + False, + True, + None, + False, + ), + # 1D: replicate_with_fsdp + ( + (ws,), + ("ddp0",), + DataParallelMeshDims(replicate="ddp0"), + False, + True, + None, + True, + ), + ] + if ws >= 4: + cases.extend( + [ + # HSDP 2D + ( + (2, ws // 2), + ("rep", "shard"), + DataParallelMeshDims(shard="shard", replicate="rep"), + False, + True, + "world", + False, + ), + # Multi-shard FSDP + ( + (2, ws // 2), + ("dp0", "dp1"), + DataParallelMeshDims(shard=("dp0", "dp1")), + False, + True, + "world", + False, + ), + # FSDP+TP + ( + (2, ws // 2), + ("fsdp1", "tp"), + DataParallelMeshDims(shard="fsdp1"), + True, + True, + "fsdp1", + False, + ), + # FSDP+TP with reshard_after_forward=False + ( + (2, ws // 2), + ("fsdp2", "tp0"), + DataParallelMeshDims(shard="fsdp2"), + True, + False, + "fsdp2", + False, + ), + # HSDP+TP 3D + ( + (1, ws // 2, 2), + ("rep0", "fsdp3", "tp1"), + DataParallelMeshDims(shard="fsdp3", replicate="rep0"), + True, + True, + "fsdp3", + False, + ), + # Multi-dim replicate + ( + (1, ws // 2, 2), + ("ddp1", "ddp2", "fsdp4"), + DataParallelMeshDims(shard="fsdp4", replicate=("ddp1", "ddp2")), + False, + True, + "world", + False, + ), + ] + ) + mlp_dim = 16 + for sizes, names, dp_dims, use_tp, reshard, dp_pg_src, use_rep in cases: + with self.subTest( + names=names, use_tp=use_tp, reshard=reshard, use_rep=use_rep + ): + mesh = world_mesh._unflatten(0, sizes, names) + + torch.manual_seed(42) + model = MLP(mlp_dim, device=device_type) + ref_model = copy.deepcopy(model) + + partition_fn = _tp_partition_fn if use_tp else None + distribute_module(model, mesh, partition_fn) + + if use_rep: + replicate_with_fsdp(model, mesh=mesh, dp_mesh_dims=dp_dims) + else: + shard_fn = _tp_shard_fn if use_tp else None + fully_shard( + model, + mesh=mesh, + reshard_after_forward=reshard, + shard_placement_fn=shard_fn, + dp_mesh_dims=dp_dims, + ) + + if dp_pg_src is None: + dp_pg = mesh.get_group() + elif dp_pg_src == "world": + dp_pg = dist.group.WORLD + else: + dp_pg = mesh[dp_pg_src].get_group() + + replicate( + ref_model, + device_ids=[self.rank] if device_type.type != "cpu" else None, + process_group=dp_pg, + ) + + self._run_train_parity( + model, ref_model, dp_pg, mesh=mesh, mlp_dim=mlp_dim + ) + dist.barrier() + + @skip_if_lt_x_gpu(2) + def test_sharded_param_correctness_1d(self): + """Verify sharded param mesh and placements for FSDP on 1D mesh.""" + mesh = init_device_mesh( + device_type.type, (self.world_size,), mesh_dim_names=("fsdp",) + ) + + model = MLP(16, device=device_type) + distribute_module(model, mesh) + fully_shard( + model, + mesh=mesh, + dp_mesh_dims=DataParallelMeshDims(shard="fsdp"), + ) + + for param in model.parameters(): + self.assertIsInstance(param, DTensor) + self.assertEqual(param.device_mesh, mesh) + self.assertEqual(len(param.placements), 1) + self.assertIsInstance(param.placements[0], Shard) + + @skip_if_lt_x_gpu(4) + def test_fsdp_tp_dtensor_sharded_params(self): + """Verify sharded param mesh and placements for FSDP+TP on 2D mesh.""" + dp_size = 2 + tp_size = self.world_size // dp_size + mesh = init_device_mesh( + device_type.type, + (dp_size, tp_size), + mesh_dim_names=("fsdp", "tp"), + ) + + mlp_dim = 16 + model = MLP(mlp_dim, device=device_type) + + def partition_fn(name, module, device_mesh): + if not isinstance(module, nn.Linear): + return + for param_name, param in list(module.named_parameters(recurse=False)): + if param_name == "weight": + if "in_proj" in name: + placements = [Replicate(), Shard(0)] + else: + placements = [Replicate(), Shard(1)] + else: + placements = [Replicate(), Replicate()] + dist_param = nn.Parameter( + distribute_tensor(param, device_mesh, placements), + requires_grad=param.requires_grad, + ) + module.register_parameter(param_name, dist_param) + + distribute_module(model, mesh, partition_fn) + + def shard_fn(param): + if any(isinstance(p, Shard) and p.dim == 0 for p in param.placements): + return Shard(1) + return Shard(0) + + fully_shard( + model, + mesh=mesh, + shard_placement_fn=shard_fn, + dp_mesh_dims=DataParallelMeshDims(shard="fsdp"), + ) + + for name, param in model.named_parameters(): + self.assertIsInstance(param, DTensor) + self.assertEqual(param.device_mesh, mesh) + self.assertEqual(len(param.placements), 2) + if "in_proj.weight" in name: + # FSDP shards dim 1 (avoiding TP dim 0), TP shards dim 0 + self.assertIsInstance(param.placements[0], Shard) + self.assertEqual(param.placements[0].dim, 1) + self.assertIsInstance(param.placements[1], Shard) + self.assertEqual(param.placements[1].dim, 0) + elif "out_proj.weight" in name: + # FSDP shards dim 0 (default), TP shards dim 1 + self.assertIsInstance(param.placements[0], Shard) + self.assertEqual(param.placements[0].dim, 0) + self.assertIsInstance(param.placements[1], Shard) + self.assertEqual(param.placements[1].dim, 1) + elif "bias" in name: + # FSDP shards dim 0 (default), TP replicates + self.assertIsInstance(param.placements[0], Shard) + self.assertEqual(param.placements[0].dim, 0) + self.assertIsInstance(param.placements[1], Replicate) + + @skip_if_lt_x_gpu(2) + def test_validation_non_replicate_dp_placement(self): + """Error when a param has non-Replicate placement on the DP shard dim.""" + mesh = init_device_mesh( + device_type.type, (self.world_size,), mesh_dim_names=("fsdp",) + ) + model = nn.Linear(16, 32, device=device_type) + # Distribute weight with Shard(0) on the FSDP dim + model.weight = nn.Parameter( + distribute_tensor(model.weight.data, mesh, [Shard(0)]), + requires_grad=True, + ) + model.bias = nn.Parameter( + distribute_tensor(model.bias.data, mesh, [Replicate()]), + requires_grad=True, + ) + with self.assertRaisesRegex(ValueError, "Expected Replicate"): + fully_shard( + model, + mesh=mesh, + dp_mesh_dims=DataParallelMeshDims(shard="fsdp"), + ) + + @skip_if_lt_x_gpu(2) + def test_validation_invalid_dim_names(self): + """Error when dp_mesh_dims references nonexistent mesh dim names.""" + mesh = init_device_mesh( + device_type.type, (self.world_size,), mesh_dim_names=("fsdp",) + ) + model = MLP(16, device=device_type) + distribute_module(model, mesh) + with self.assertRaisesRegex(ValueError, "not found in mesh.mesh_dim_names"): + fully_shard( + model, + mesh=mesh, + dp_mesh_dims=DataParallelMeshDims(shard="nonexistent"), + ) + + @skip_if_lt_x_gpu(2) + def test_validation_mesh_mismatch(self): + """Error when param DTensor mesh differs from the mesh passed to fully_shard.""" + mesh1 = init_device_mesh( + device_type.type, (self.world_size,), mesh_dim_names=("fsdp",) + ) + mesh2 = init_device_mesh( + device_type.type, (self.world_size,), mesh_dim_names=("fsdp",) + ) + model = nn.Linear(16, 32, device=device_type) + # Distribute params on mesh1 but pass mesh2 to fully_shard + model.weight = nn.Parameter( + distribute_tensor(model.weight.data, mesh1, [Replicate()]), + requires_grad=True, + ) + model.bias = nn.Parameter( + distribute_tensor(model.bias.data, mesh1, [Replicate()]), + requires_grad=True, + ) + with self.assertRaisesRegex(ValueError, "same mesh"): + fully_shard( + model, + mesh=mesh2, + dp_mesh_dims=DataParallelMeshDims(shard="fsdp"), + ) + + def test_validation_at_least_one_required(self): + """Error when neither shard nor replicate is set.""" + with self.assertRaisesRegex(ValueError, "At least one of shard or replicate"): + DataParallelMeshDims() + + @skip_if_lt_x_gpu(2) + def test_validation_spmd_mesh_non_dtensor_params(self): + """Error when dp_mesh_dims is provided but params are not DTensors.""" + mesh = init_device_mesh( + device_type.type, (self.world_size,), mesh_dim_names=("fsdp",) + ) + model = MLP(16, device=device_type) + # Do NOT call distribute_module -- params are plain tensors + with self.assertRaisesRegex(ValueError, "must be DTensors"): + fully_shard( + model, + mesh=mesh, + dp_mesh_dims=DataParallelMeshDims(shard="fsdp"), + ) + + @skip_if_lt_x_gpu(2) + def test_validation_reshard_after_forward_int_spmd(self): + """Error when reshard_after_forward is int with SPMD mesh.""" + mesh = init_device_mesh( + device_type.type, (self.world_size,), mesh_dim_names=("fsdp",) + ) + model = MLP(16, device=device_type) + distribute_module(model, mesh) + with self.assertRaisesRegex( + NotImplementedError, "reshard_after_forward as int" + ): + fully_shard( + model, + mesh=mesh, + reshard_after_forward=2, + dp_mesh_dims=DataParallelMeshDims(shard="fsdp"), + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/test_utils.py b/test/distributed/checkpoint/test_utils.py index 7ca5577bba20a..b6f66ba97a52b 100644 --- a/test/distributed/checkpoint/test_utils.py +++ b/test/distributed/checkpoint/test_utils.py @@ -2,6 +2,7 @@ import io import sys +import traceback import torch import torch.distributed as dist @@ -13,6 +14,7 @@ ) from torch.distributed._shard.sharded_tensor.metadata import TensorProperties from torch.distributed.c10d_logger import _c10d_logger +from torch.distributed.checkpoint.api import _wrap_exception from torch.distributed.checkpoint.logger import _dcp_logger from torch.distributed.checkpoint.metadata import MetadataIndex from torch.distributed.checkpoint.utils import ( @@ -20,6 +22,7 @@ _DistWrapper, find_state_dict_object, ) +from torch.distributed.distributed_c10d import _object_to_tensor, _tensor_to_object from torch.testing._internal.common_utils import ( run_tests, TEST_WITH_DEV_DBG_ASAN, @@ -139,6 +142,44 @@ def test_dcp_logger(self): self.assertEqual(1, len(_c10d_logger.handlers)) +class TestWrapException(TestCase): + def test_wrap_exception_serializable_via_object_to_tensor(self): + """Verify _wrap_exception produces a result that _object_to_tensor can serialize. + + Python 3.13+ adds a _code attribute to FrameSummary containing + bytecode objects that cannot be pickled. _wrap_exception must + clear these so that _object_to_tensor (used by gather_object and + scatter_object_list) succeeds instead of raising + "TypeError: cannot pickle code objects". + """ + try: + raise ValueError("test error") + except ValueError as e: + wrapped = _wrap_exception(e) + + # _object_to_tensor / _tensor_to_object are what gather_object + # and scatter_object_list use to serialize objects across ranks. + # This would raise "TypeError: cannot pickle code objects" + # on Python 3.13+ without the fix. + byte_tensor, size = _object_to_tensor(wrapped, torch.device("cpu"), None) + restored = _tensor_to_object(byte_tensor, size.item(), None) + + self.assertIsInstance(restored[0], ValueError) + self.assertEqual(str(restored[0]), "test error") + self.assertIsInstance(restored[1], traceback.StackSummary) + self.assertGreater(len(restored[1]), 0) + + def test_wrap_exception_preserves_traceback_formatting(self): + """Verify that clearing _code does not break traceback formatting.""" + try: + raise RuntimeError("format test") + except RuntimeError as e: + wrapped = _wrap_exception(e) + + formatted = "".join(traceback.format_list(wrapped[1])) + self.assertIn("raise RuntimeError", formatted) + + class TestReaderView(TestCase): def setUp(self): buffer = io.BytesIO(bytearray(range(ord("A"), ord("Z") + 1))) diff --git a/test/distributed/tensor/debug/test_comm_mode.py b/test/distributed/tensor/debug/test_comm_mode.py index d122a9f716fcd..a8f22333a95d5 100644 --- a/test/distributed/tensor/debug/test_comm_mode.py +++ b/test/distributed/tensor/debug/test_comm_mode.py @@ -5,6 +5,7 @@ import torch.distributed._functional_collectives as funcol import torch.nn as nn from torch.distributed.tensor import DeviceMesh, DTensor, Shard +from torch.distributed.tensor._redistribute import use_min_cost_redistribution_plan from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_distributed import requires_accelerator_dist_backend from torch.testing._internal.common_utils import run_tests, TestCase @@ -105,7 +106,7 @@ def f(x, y): x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False) y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False) - with comm_mode: + with comm_mode, use_min_cost_redistribution_plan(): f(x_dtensor, y_dtensor) comm_counts = comm_mode.get_comm_counts() diff --git a/test/distributed/tensor/test_dtensor.py b/test/distributed/tensor/test_dtensor.py index 576abcfa820fe..fa3df1373cc59 100644 --- a/test/distributed/tensor/test_dtensor.py +++ b/test/distributed/tensor/test_dtensor.py @@ -632,9 +632,7 @@ def test_dtensor_new_empty_strided(self): # test backward new_empty_strided with sharding works correctly my_dtensor.to_local().sum().backward() local_tensor.sum().backward() - self.assertEqual( - my_dtensor.grad.full_tensor(), new_strided_dtensor.grad.full_tensor() - ) + self.assertEqual(my_dtensor.grad, new_strided_dtensor.grad) self.assertEqual( my_dtensor.grad.redistribute(placements=[Replicate()]).to_local(), local_tensor.grad, @@ -912,50 +910,6 @@ def backward(ctx, grad_out1, grad_out2): loss = out1.sum() loss.backward() - @with_comms - def test_assert_equal_dtensor(self): - mesh = self.build_device_mesh() - local = torch.randn(4, 4, device=self.device_type) - - dt1 = DTensor.from_local(local.clone(), mesh, [Replicate()]) - dt2 = DTensor.from_local(local.clone(), mesh, [Replicate()]) - - self.assertEqual(dt1, dt2) - torch.testing.assert_close(dt1, dt2) - - dt3 = DTensor.from_local( - torch.randn(4, 4, device=self.device_type), mesh, [Replicate()] - ) - with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close"): - self.assertEqual(dt1, dt3) - - dt_shard = DTensor.from_local(local.clone(), mesh, [Shard(0)]) - with self.assertRaisesRegex(AssertionError, "DTensor placements do not match"): - self.assertEqual(dt1, dt_shard) - - with self.assertRaisesRegex( - TypeError, "Comparing a DTensor to a non-DTensor is ambiguous" - ): - self.assertEqual(dt1, local) - with self.assertRaisesRegex( - TypeError, "Comparing a DTensor to a non-DTensor is ambiguous" - ): - self.assertEqual(local, dt1) - with self.assertRaisesRegex( - TypeError, "Comparing a DTensor to a non-DTensor is ambiguous" - ): - torch.testing.assert_close(dt1, local) - - dt_scalar = DTensor.from_local( - torch.tensor(42.0, device=self.device_type), mesh, [Replicate()] - ) - with self.assertRaisesRegex( - TypeError, "Comparing a DTensor to a non-DTensor is ambiguous" - ): - self.assertEqual(dt_scalar, 42.0) - self.assertEqual(dt_scalar.full_tensor(), 42.0) - self.assertEqual(dt_scalar.to_local(), 42.0) - DTensorTestWithLocalTensor = create_local_tensor_test_class( DTensorTest, diff --git a/test/distributed/tensor/test_dtensor_export.py b/test/distributed/tensor/test_dtensor_export.py index 262f458f87e42..64fd06dcbf742 100644 --- a/test/distributed/tensor/test_dtensor_export.py +++ b/test/distributed/tensor/test_dtensor_export.py @@ -543,7 +543,18 @@ def forward(self, x): %item : [num_users=2] = call_method[target=item](args = (%clamp,), kwargs = {}) %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%item, 1), kwargs = {}) %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 1 on node 'ge_1'), kwargs = {}) - %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%l_x_, slice(None, item, None)), kwargs = {}) + %getitem : [num_users=3] = call_function[target=operator.getitem](args = (%l_x_, slice(None, item, None)), kwargs = {}) + %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%getitem, _local_tensor), kwargs = {}) + %sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getattr_1, 0), kwargs = {}) + %sym_size_int_1 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem, 0), kwargs = {}) + %ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {}) + %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u2 >= 0 on node 'ge_2'), kwargs = {}) + %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 4), kwargs = {}) + %_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u2 <= 4 on node 'le'), kwargs = {}) + %ge_3 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {}) + %_assert_scalar_default_3 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_3, Runtime assertion failed for expression u1 >= 0 on node 'ge_3'), kwargs = {}) + %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 4), kwargs = {}) + %_assert_scalar_default_4 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u1 <= 4 on node 'le_1'), kwargs = {}) return (getitem,)""", # noqa: B950 ) diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index 57710a332c2cb..577bdb4f9522a 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -605,12 +605,7 @@ def forward(self, tokens): with CommDebugMode() as comm_mode: output_dist = model_dist(x) - output_dist_cmp = ( - output_dist.full_tensor() - if isinstance(output_dist, DTensor) - else output_dist - ) - self.assertEqual(output_local, output_dist_cmp) + self.assertEqual(output_local, output_dist) # all requires_grad patterns should have the same forward comm counts expected_fwd_comm = { @@ -650,12 +645,7 @@ def forward(self, tokens): comm_mode.comm_module_counts["Global"]["backward"], expected_bwd_comm, ) - output_dist_cmp = ( - output_dist.full_tensor() - if isinstance(output_dist, DTensor) - else output_dist - ) - self.assertEqual(output_local, output_dist_cmp) + self.assertEqual(output_local, output_dist) except Exception as e: subtest_fails[subtest_cfg] = e @@ -1284,9 +1274,7 @@ def test_partial_reduction_ops(self): dt = dt.redistribute(dt.device_mesh, placements=[Replicate()]) out_with_redistribute = torch.norm(dt) - self.assertEqual( - out_without_redistribute.full_tensor(), out_with_redistribute.full_tensor() - ) + self.assertEqual(out_without_redistribute, out_with_redistribute) local_tensor = torch.rand(3, dtype=torch.float32, device=self.device_type) dt = DTensor.from_local( @@ -1297,9 +1285,7 @@ def test_partial_reduction_ops(self): dt = dt.redistribute(dt.device_mesh, placements=[Replicate()]) out_with_redistribute = torch.max(dt) - self.assertEqual( - out_without_redistribute.full_tensor(), out_with_redistribute.full_tensor() - ) + self.assertEqual(out_without_redistribute, out_with_redistribute) local_tensor = torch.rand(3, dtype=torch.float32, device=self.device_type) dt = DTensor.from_local( @@ -1310,9 +1296,7 @@ def test_partial_reduction_ops(self): dt = dt.redistribute(dt.device_mesh, placements=[Replicate()]) out_with_redistribute = torch.min(dt) - self.assertEqual( - out_without_redistribute.full_tensor(), out_with_redistribute.full_tensor() - ) + self.assertEqual(out_without_redistribute, out_with_redistribute) @with_comms def test_matching_partial_reduction_ops(self): @@ -1331,9 +1315,7 @@ def test_matching_partial_reduction_ops(self): self.assertTrue(out_without_redistribute.placements[0].is_partial()) self.assertTrue(out_with_redistribute.placements[0].is_replicate()) - self.assertEqual( - out_without_redistribute.full_tensor(), out_with_redistribute.full_tensor() - ) + self.assertEqual(out_without_redistribute, out_with_redistribute) @skip_if_lt_x_gpu(4) @with_comms diff --git a/test/distributed/tensor/test_pointwise_ops.py b/test/distributed/tensor/test_pointwise_ops.py index 317cdab26d319..6a1646d0f93b1 100644 --- a/test/distributed/tensor/test_pointwise_ops.py +++ b/test/distributed/tensor/test_pointwise_ops.py @@ -546,7 +546,7 @@ def test_mul_div_scalar_partial(self): self.assertTrue(res._spec.placements[0].is_partial()) res = res.redistribute(dt.device_mesh, placements=[Replicate()]) expected = sum(i for i in range(self.world_size)) * 2 - self.assertEqual(res.full_tensor(), expected) + self.assertEqual(res, expected) res = aten.div.Scalar(dt, 2) self.assertEqual( @@ -557,7 +557,7 @@ def test_mul_div_scalar_partial(self): self.assertTrue(res._spec.placements[0].is_partial()) res = res.redistribute(dt.device_mesh, placements=[Replicate()]) expected = sum(i for i in range(self.world_size)) / 2 - self.assertEqual(res.full_tensor(), expected) + self.assertEqual(res, expected) @with_comms def test_mul_div_scalar_norm_partial(self): @@ -589,7 +589,7 @@ def test_add_sub_scalar_partial(self): res = dt + 1 expected = sum(i for i in range(self.world_size)) + 1 - self.assertEqual(res.full_tensor(), expected) + self.assertEqual(res, expected) self.assertTrue(res._spec.placements[0].is_replicate()) # regular partial - scalar -> replicate @@ -601,12 +601,12 @@ def test_add_sub_scalar_partial(self): res = dt - 1 expected = sum(i for i in range(self.world_size)) - 1 - self.assertEqual(res.full_tensor(), expected) + self.assertEqual(res, expected) self.assertTrue(res._spec.placements[0].is_replicate()) res = 7 - dt expected = 7 - sum(i for i in range(self.world_size)) - self.assertEqual(res.full_tensor(), expected) + self.assertEqual(res, expected) self.assertTrue(res._spec.placements[0].is_replicate()) # regular partial + regular partial -> partial @@ -615,14 +615,14 @@ def test_add_sub_scalar_partial(self): self.assertTrue(res._spec.placements[0].is_partial()) res = res.redistribute(dt.device_mesh, placements=[Replicate()]) expected = sum(i for i in range(self.world_size)) * 2 - self.assertEqual(res.full_tensor(), expected) + self.assertEqual(res, expected) # regular partial - regular partial -> partial res = dt - dt self.assertEqual(res.to_local(), rank - rank) self.assertTrue(res._spec.placements[0].is_partial()) res = res.redistribute(dt.device_mesh, placements=[Replicate()]) - self.assertEqual(res.full_tensor(), 0) + self.assertEqual(res, 0) @with_comms def test_add_sub_scalar_norm_partial(self): @@ -636,7 +636,7 @@ def test_add_sub_scalar_norm_partial(self): self.assertTrue(isinstance(norm._spec.placements[0], _NormPartial)) norm = norm + 1 - self.assertEqual(norm.full_tensor(), 11) + self.assertEqual(norm, 11) self.assertTrue(norm._spec.placements[0].is_replicate()) dt = distribute_tensor(local_tensor, mesh, [Shard(0)]) @@ -645,7 +645,7 @@ def test_add_sub_scalar_norm_partial(self): self.assertTrue(isinstance(norm._spec.placements[0], _NormPartial)) norm = norm - 1 - self.assertEqual(norm.full_tensor(), 9) + self.assertEqual(norm, 9) self.assertTrue(norm._spec.placements[0].is_replicate()) @with_comms diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 95d500ac38df5..ca589e53c5187 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -1207,6 +1207,33 @@ def test_bucketize_sharded_boundaries(self): self.assertEqual(result.full_tensor(), expected) +class DistToCopyTest(LocalDTensorTestBase): + @with_comms + def test_to_copy_partial_reduces_for_nonlinear_cast(self): + # (reduce_op, target_dtype, expect_partial) + cases = [ + ("sum", torch.int32, False), # truncation breaks additivity + ("sum", torch.bool, False), # thresholding + ("sum", torch.float64, True), # float→float is safe + ("max", torch.int32, True), # monotonic + ("max", torch.bool, False), # thresholding + ] + with LocalTensorMode(ranks=self.world_size): + mesh = self.build_device_mesh() + input_tensor = torch.randn(4, 4, device=self.device_type) + for reduce_op, target_dtype, expect_partial in cases: + dt = DTensor.from_local(input_tensor, mesh, [Partial(reduce_op)]) + result = dt.to(target_dtype) + p = result.placements[0] + if expect_partial: + self.assertTrue(p.is_partial(), f"{reduce_op}→{target_dtype}: {p}") + self.assertEqual(p.reduce_op, reduce_op) + else: + self.assertTrue( + p.is_replicate(), f"{reduce_op}→{target_dtype}: {p}" + ) + + class DistArgMaxArgMinTest(DTensorContinuousTestBase): world_size = 4 _ops = [torch.argmax, torch.argmin] diff --git a/test/distributed/test_aten_comm_compute_reordering.py b/test/distributed/test_aten_comm_compute_reordering.py index f355199dce26d..396c240cdccd3 100644 --- a/test/distributed/test_aten_comm_compute_reordering.py +++ b/test/distributed/test_aten_comm_compute_reordering.py @@ -1216,6 +1216,73 @@ def func(a, b, c, d, *, ranks): correct = func(a, b, c, d, ranks=ranks) self.assertTrue(same(out, correct)) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @torch._inductor.config.patch( + { + **get_bucket_patches(), + "aten_distributed_optimizations.enable_overlap_scheduling": True, + } + ) + def test_uneven_sharding_spmd_graphs(self): + """Test that uneven DTensor sharding produces SPMD graphs across ranks. + + When a tensor dimension is not divisible by world_size, DTensor pads + before all_gather and unpads after. The no-op pad/unpad on ranks with + full-size shards must not be eliminated by remove_noop_ops, so all + ranks produce identical FX graphs with matching op counts. + """ + + def func(a, *, ranks): + # Simulate DTensor's pad-before-all_gather for uneven shards. + # rank 0: a is (4, 8), pad_size=0 → no-op pad + # rank 1: a is (3, 8), pad_size=1 → real pad to (4, 8) + full_chunk = (7 + len(ranks) - 1) // len(ranks) + pad_size = full_chunk - a.size(0) + a_padded = torch.nn.functional.pad(a, [0, 0, 0, pad_size]) + ag = _functional_collectives.all_gather_tensor(a_padded, 0, ranks) + # Unpad after all_gather: narrow to original logical size + result = ag.narrow(0, 0, 7) + return result + 1 + + with _dynamo_dist_per_rank_init( + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), + ): + import torch.distributed as dist + from torch._subclasses.fake_tensor import unset_fake_temporarily + + world_size = self.world_size + # 7 is not divisible by 2: rank 0 gets 4 rows, rank 1 gets 3 + full_chunk = (7 + world_size - 1) // world_size + local_size = full_chunk if self.rank == 0 else 7 - full_chunk + a = torch.randn(local_size, 8, device=device_type) + ranks = list(range(world_size)) + + func_c = functools.partial(func, ranks=ranks) + compiled = torch.compile(func_c) + out, aten_graph_str = run_and_get_aten_graph(compiled, a) + + # Build structural fingerprint: sorted list of call_function targets. + # Node names differ across ranks, but targets and op counts must match. + targets_r = sorted( + str(n_line.split("target=")[1].split("]")[0]) + for n_line in aten_graph_str.split("\n") + if "call_function" in n_line and "target=" in n_line + ) + + with unset_fake_temporarily(): + all_targets: list[list[str] | None] = [None] * world_size + dist.all_gather_object(all_targets, targets_r) + + self.assertEqual( + all_targets[0], + all_targets[1], + "FX graph op targets differ across ranks — not SPMD. " + "No-op pad/slice may have been eliminated by remove_noop_ops.", + ) + def get_toy_model(device_type: str): """ diff --git a/test/distributed/test_cupy_as_tensor.py b/test/distributed/test_cupy_as_tensor.py index 63b290e2e8e66..06b978ca91236 100644 --- a/test/distributed/test_cupy_as_tensor.py +++ b/test/distributed/test_cupy_as_tensor.py @@ -8,10 +8,7 @@ import torch from torch.multiprocessing.reductions import reduce_tensor from torch.testing._internal.common_cuda import SM100OrLater -from torch.testing._internal.common_distributed import ( - MultiProcContinuousTest, - skip_if_rocm_multiprocess, -) +from torch.testing._internal.common_distributed import MultiProcContinuousTest from torch.testing._internal.common_utils import ( requires_cuda_p2p_access, run_tests, @@ -70,7 +67,6 @@ def _init_device(self) -> None: def device(self) -> torch.device: return torch.device(device_type, self.rank) - @skip_if_rocm_multiprocess # RuntimeError: pidfd_getfd Operation not permitted" @skip_but_pass_in_sandcastle_if( SM100OrLater, "Fails if ran in docker environment without privileged access (https://github.com/pytorch/pytorch/issues/165170)", diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index e911d34f01e23..b9eb142d83cf8 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -1902,5 +1902,39 @@ def test_remap_to_tensor(self): self.assertEqual(result7, expected7) +class ProcessGroupOpaqueTypeTest(TestCase): + """Test that ProcessGroup opaque type members are registered and exist on the class.""" + + def test_registered_members_exist_on_process_group(self): + from torch._library.opaque_object import get_member_type + + # Every member registered in _register_distributed_opaque_types() + # must actually exist on ProcessGroup. This catches renames or + # removals of C++ attributes that would cause torch.compile + # (fullgraph=True) to silently register a stale name while the + # real attribute has moved. + registered_members = [ + "size", + "rank", + "_get_backend_name", + "group_name", + "group_desc", + "__eq__", + ] + for member_name in registered_members: + self.assertIsNotNone( + get_member_type(ProcessGroup, member_name), + f"'{member_name}' is not registered as a ProcessGroup opaque " + f"type member. Add it to _register_distributed_opaque_types() " + f"in torch/distributed/device_mesh.py", + ) + self.assertTrue( + hasattr(ProcessGroup, member_name), + f"'{member_name}' is registered as a ProcessGroup opaque type " + f"member but does not exist on the ProcessGroup class. " + f"Was it renamed or removed?", + ) + + if __name__ == "__main__": run_tests() diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index 3129ddf7ccd11..d9d0dd4f80508 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -1,8 +1,10 @@ # Owner(s): ["module: dynamo"] +import contextlib import copy import dataclasses import functools +import multiprocessing import operator import os import pickle @@ -20,6 +22,7 @@ from torch._functorch import config as functorch_config from torch._functorch._aot_autograd.autograd_cache import ( AOTAutogradCache, + AOTAutogradCachePickler, autograd_cache_key, BypassAOTAutogradCache, sanitize_gm_for_cache, @@ -73,6 +76,20 @@ def uuid(self): custom_pre_grad_pass_remove_ident_muls = CustomPreGradPassRemoveIdentMuls() +@contextlib.contextmanager +def _fake_process_group(rank=0, world_size=2): + """Context manager for setting up and tearing down a fake process group.""" + import torch.distributed as c10d + from torch.testing._internal.distributed.fake_pg import FakeStore + + fake_store = FakeStore() + c10d.init_process_group("fake", store=fake_store, rank=rank, world_size=world_size) + try: + yield + finally: + c10d.destroy_process_group() + + def aot_eager_regional_inductor(): """ Regional inductor backend for AOT autograd. @@ -3073,6 +3090,45 @@ def run_script(pass_uuid): self.assertEqual(c3.get("autograd_cache_miss", 0), 1) self.assertEqual(c3.get("autograd_cache_hit", 0), 0) + @unittest.skipIf(not torch.distributed.is_available(), "requires distributed") + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + def test_dtensor_cache_hit(self): + """ + Test that DTensor produces cache hits on second compile. + + This follows the standard AOTAutograd cache test pattern: compile once + (cache miss), reset dynamo, compile again with equivalent input (cache hit). + """ + from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.tensor import DTensor, Replicate + + with _fake_process_group(): + mesh = init_device_mesh("cpu", (2,)) + + def fn(x): + return x.sin() + + compiled_fn = torch.compile(fn, backend="inductor") + + # First call - cache miss + dtensor1 = DTensor.from_local(torch.zeros(4, 4), mesh, [Replicate()]) + compiled_fn(dtensor1) + + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + + # Reset dynamo but keep the cache + self._clear_dynamo_and_codecache() + + # Second call with equivalent DTensor - should hit cache + dtensor2 = DTensor.from_local(torch.zeros(4, 4), mesh, [Replicate()]) + compiled_fn(dtensor2) + + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) + @functorch_config.patch({"bundled_autograd_cache": True}) class AOTAutogradCacheBundledTests(AOTAutogradCacheTests): @@ -3461,6 +3517,176 @@ def test_pickle_entry_strict_mode_raises(self): ): AOTAutogradCache._pickle_entry(entry, remote=False) + def test_nested_tensor_subclass_cache_key(self): + ctx = multiprocessing.get_context("spawn") + results = [] + for _ in range(2): + queue = ctx.Queue() + p = ctx.Process( + target=_subprocess_gen_nested_subclass_cache_key, + args=(queue,), + ) + p.start() + p.join() + + if p.exitcode != 0: + self.skipTest(f"Subprocess exited with code {p.exitcode}") + + results.append(queue.get()) + + # Cache keys from two different processes should be identical + self.assertEqual( + results[0], + results[1], + "Nested tensor subclass cache keys should be identical across processes", + ) + + @unittest.skipIf(not torch.distributed.is_available(), "requires distributed") + def test_dtensor_different_process_cache_key(self): + """ + Test that DTensor cache keys are consistent across different processes. + + This is a critical test for warm start cache hits. The cache key must be + deterministic and not depend on process-specific values like memory addresses + or object ids that would differ between processes. + """ + ctx = multiprocessing.get_context("spawn") + results = [] + for _ in range(2): + queue = ctx.Queue() + p = ctx.Process( + target=_subprocess_gen_dtensor_cache_key, + args=(queue,), + ) + p.start() + p.join() + + if p.exitcode != 0: + self.skipTest(f"Subprocess exited with code {p.exitcode}") + + results.append(queue.get()) + + # Cache keys from two different processes should be identical + self.assertEqual( + results[0], + results[1], + "DTensor cache keys should be identical across processes", + ) + + @unittest.skipIf(not torch.distributed.is_available(), "requires distributed") + def test_dtensor_different_placements_different_cache_key(self): + from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.tensor import DTensor, Replicate, Shard + + with _fake_process_group(): + mesh = init_device_mesh("cpu", (2,)) + + local_tensor = torch.zeros(4, 4) + dtensor_replicate = DTensor.from_local(local_tensor, mesh, [Replicate()]) + dtensor_shard = DTensor.from_local(local_tensor, mesh, [Shard(0)]) + + def fn(x): + return x.sin() + + # Get dynamo output for replicate + torch._dynamo.reset() + fx_graph = None + + def compiler(gm, inputs, **kwargs): + nonlocal fx_graph + fx_graph = gm + return gm + + g = torch.compile(fn, backend=compiler, fullgraph=True) + g(dtensor_replicate) + + pickler = AOTAutogradCachePickler(fx_graph) + cache_key_replicate = pickler.get_hash(dtensor_replicate) + cache_key_shard = pickler.get_hash(dtensor_shard) + + # Different placements should produce different cache keys + self.assertNotEqual( + cache_key_replicate, + cache_key_shard, + "DTensor with Replicate() should have different cache key than Shard(0)", + ) + + +def _subprocess_gen_dtensor_cache_key(queue): + """ + Subprocess helper to generate a DTensor cache key. + Must be at module level for multiprocessing to work. + """ + import torch + import torch._dynamo + import torch.distributed as c10d + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCachePickler + from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.tensor import DTensor, Replicate + from torch.testing._internal.distributed.fake_pg import FakeStore + + fake_store = FakeStore() + c10d.init_process_group("fake", store=fake_store, rank=0, world_size=2) + try: + mesh = init_device_mesh("cpu", (2,)) + + # Create DTensor + local_tensor = torch.zeros(4, 4) + dtensor = DTensor.from_local(local_tensor, mesh, [Replicate()]) + + def fn(x): + return x.sin() + + # Get dynamo output + torch._dynamo.reset() + fx_graph = None + + def compiler(gm, inputs, **kwargs): + nonlocal fx_graph + fx_graph = gm + return gm + + g = torch.compile(fn, backend=compiler, fullgraph=True) + g(dtensor) + + pickler = AOTAutogradCachePickler(fx_graph) + cache_key = pickler.get_hash(dtensor) + + queue.put(cache_key) + finally: + c10d.destroy_process_group() + + +def _subprocess_gen_nested_subclass_cache_key(queue): + import torch + import torch._dynamo + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCachePickler + from torch.testing._internal.two_tensor import TwoTensor + + inner1 = TwoTensor(torch.zeros(4, 4), torch.zeros(4, 4)) + inner2 = TwoTensor(torch.zeros(4, 4), torch.zeros(4, 4)) + nested_two_tensor = TwoTensor(inner1, inner2) + + def fn(x): + return x.sin() + + # Get dynamo output + torch._dynamo.reset() + fx_graph = None + + def compiler(gm, inputs, **kwargs): + nonlocal fx_graph + fx_graph = gm + return gm + + g = torch.compile(fn, backend=compiler, fullgraph=True) + g(nested_two_tensor) + + pickler = AOTAutogradCachePickler(fx_graph) + cache_key = pickler.get_hash(nested_two_tensor) + + queue.put(cache_key) + def _policy_save_mm(ctx, op, *args, **kwargs): if op == torch.ops.aten.mm.default: diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index f65d860fab64a..f03759df31f93 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -2,15 +2,24 @@ import contextlib import sys import unittest +from collections import defaultdict from contextlib import contextmanager import torch import torch._dynamo.test_case import torch._dynamo.testing -from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same +from torch._dynamo.testing import ( + check_dynamic_shape_capture, + EagerAndRecordGraphs, + normalize_gm, + same, +) from torch._dynamo.utils import counters from torch.nn import functional as F -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_FLASH_ATTENTION, + TEST_MULTIGPU, +) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -583,6 +592,39 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) + @unittest.skipIf(not TEST_MULTIGPU, "Requires multiple gpus") + def test_cuda__exchange_device(self): + def fn(x): + dev = torch.cuda._exchange_device(0) + x = torch.sin(x + dev) + torch.cuda._maybe_exchange_device(dev) + return x + + initial_dev = torch.cuda.current_device() + x = torch.randn((2, 2), device="cuda") + ref = fn(x) + opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) + res = opt_fn(x) + self.assertEqual(ref, res) + + # make sure we recompile if device changes + with torch.cuda.device(1): + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + self.assertEqual(torch.cuda.current_device(), initial_dev) + + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + def test_cuda__exchange_device_args(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(args, kwargs): + torch.cuda._exchange_device(*args, **kwargs) + + initial_dev = torch.cuda.current_device() + for args, kwargs in (((), ()), ((0, 0), ()), ((), ("kwarg",))): + self.assertRaises(torch._dynamo.exc.Unsupported, fn, args, kwargs) + self.assertEqual(torch.cuda.current_device(), initial_dev) + def test_autograd_profiler_enabled(self): def fn(x): if torch.autograd._profiler_enabled(): @@ -1012,6 +1054,348 @@ def fn(a, b): self.assertTrue(res[0].dtype == torch.float16) self.assertTrue(res[1].dtype == torch.float16) + def test__enter__exit_autocast(self): + def f(x, y): + m = torch.amp.autocast_mode._enter_autocast("cpu") + x = x @ y + torch.amp.autocast_mode._exit_autocast(m) + return x + + eager = EagerAndRecordGraphs() + opt_f = torch.compile(f, backend=eager, fullgraph=True) + x = torch.randn(3, 3, dtype=torch.float32) + y = torch.randn(3, 3, dtype=torch.float32) + z = f(x, y) + opt_z = opt_f(x, y) + self.assertEqual(z, opt_z) + self.assertEqual(z.dtype, opt_z.dtype) + self.assertFalse(torch.is_autocast_enabled("cpu")) + graph = eager.graphs[0] + actual = normalize_gm(graph.print_readable(False)) + + if check_dynamic_shape_capture(): + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, s77]", L_y_: "f32[s77, s77]"): + l_x_ = L_x_ + l_y_ = L_y_ + + _is_autocast_available = torch._C._is_autocast_available('cpu'); _is_autocast_available = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', True); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + autocast_increment_nesting = torch.autocast_increment_nesting(); autocast_increment_nesting = None + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + + x: "bf16[s77, s77]" = l_x_ @ l_y_; l_x_ = l_y_ = None + + autocast_decrement_nesting = torch.autocast_decrement_nesting(); autocast_decrement_nesting = None + + clear_autocast_cache = torch.clear_autocast_cache(); clear_autocast_cache = None + + set_autocast_enabled_1 = torch.set_autocast_enabled('cpu', False); set_autocast_enabled_1 = None + + set_autocast_dtype_1 = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype_1 = None + + set_autocast_cache_enabled_1 = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled_1 = None + return (x,) +""", # NOQA: B950 + ) + else: + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"): + l_x_ = L_x_ + l_y_ = L_y_ + + _is_autocast_available = torch._C._is_autocast_available('cpu'); _is_autocast_available = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', True); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + autocast_increment_nesting = torch.autocast_increment_nesting(); autocast_increment_nesting = None + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + + x: "bf16[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None + + autocast_decrement_nesting = torch.autocast_decrement_nesting(); autocast_decrement_nesting = None + + clear_autocast_cache = torch.clear_autocast_cache(); clear_autocast_cache = None + + set_autocast_enabled_1 = torch.set_autocast_enabled('cpu', False); set_autocast_enabled_1 = None + + set_autocast_dtype_1 = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype_1 = None + + set_autocast_cache_enabled_1 = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled_1 = None + return (x,) +""", # NOQA: B950 + ) + + def test__enter__exit_autocast_graph_break(self): + def f(x, y, z): + m = torch.amp.autocast_mode._enter_autocast("cpu") + x = x @ y + torch._dynamo.graph_break() + x = x @ z + # At this point m is wrapped as an AutocastModeVariable, which will graph break on the __exit__ call + torch.amp.autocast_mode._exit_autocast(m) + return x + + eager = EagerAndRecordGraphs() + opt_f = torch.compile(f, backend=eager, fullgraph=False) + x = torch.randn(3, 3, dtype=torch.float32) + y = torch.randn(3, 3, dtype=torch.float32) + z = torch.randn(3, 3, dtype=torch.float32) + out = f(x, y, z) + opt_out = opt_f(x, y, z) + self.assertEqual(out, opt_out) + self.assertEqual(out.dtype, opt_out.dtype) + self.assertFalse(torch.is_autocast_enabled("cpu")) + graph = eager.graphs[0] + actual = normalize_gm(graph.print_readable(False)) + + if check_dynamic_shape_capture(): + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, s77]", L_y_: "f32[s77, s77]"): + l_x_ = L_x_ + l_y_ = L_y_ + + _is_autocast_available = torch._C._is_autocast_available('cpu'); _is_autocast_available = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', True); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + autocast_increment_nesting = torch.autocast_increment_nesting(); autocast_increment_nesting = None + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + + x: "bf16[s77, s77]" = l_x_ @ l_y_; l_x_ = l_y_ = None + return (x,) +""", # NOQA: B950 + ) + else: + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"): + l_x_ = L_x_ + l_y_ = L_y_ + + _is_autocast_available = torch._C._is_autocast_available('cpu'); _is_autocast_available = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', True); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + autocast_increment_nesting = torch.autocast_increment_nesting(); autocast_increment_nesting = None + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + + x: "bf16[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None + return (x,) +""", # NOQA: B950 + ) + + # Doesn't include autocast functions, see comment above + graph = eager.graphs[1] + actual = normalize_gm(graph.print_readable(False)) + + if check_dynamic_shape_capture(): + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, s77: "Sym(s77)", L_x_: "bf16[s77, s77]", L_z_: "f32[s77, s77]"): + l_x_ = L_x_ + l_z_ = L_z_ + + x: "bf16[s77, s77]" = l_x_ @ l_z_; l_x_ = l_z_ = None + + autocast_decrement_nesting = torch.autocast_decrement_nesting(); autocast_decrement_nesting = None + + clear_autocast_cache = torch.clear_autocast_cache(); clear_autocast_cache = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', False); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + return (x,) +""", # NOQA: B950 + ) + else: + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "bf16[3, 3]", L_z_: "f32[3, 3]"): + l_x_ = L_x_ + l_z_ = L_z_ + + x: "bf16[3, 3]" = l_x_ @ l_z_; l_x_ = l_z_ = None + + autocast_decrement_nesting = torch.autocast_decrement_nesting(); autocast_decrement_nesting = None + + clear_autocast_cache = torch.clear_autocast_cache(); clear_autocast_cache = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', False); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + return (x,) +""", # NOQA: B950 + ) + + def test_autocast_low_level_api(self): + def f(x, y): + torch.set_autocast_enabled("cpu", True) + torch.set_autocast_dtype("cpu", torch.bfloat16) + torch.set_autocast_cache_enabled(True) + x = x @ y + torch.autocast_decrement_nesting() + torch.clear_autocast_cache() + torch.set_autocast_enabled("cpu", False) + return x + + prev_enabled = torch.is_autocast_enabled("cpu") + prev_dtype = torch.get_autocast_dtype("cpu") + prev_cache = torch.is_autocast_cache_enabled() + + try: + opt_f = torch.compile(f, backend="eager", fullgraph=True) + x = torch.randn(3, 3, dtype=torch.float32) + y = torch.randn(3, 3, dtype=torch.float32) + out = f(x, y) + opt_out = opt_f(x, y) + self.assertEqual(out, opt_out) + self.assertEqual(out.dtype, opt_out.dtype) + self.assertFalse(torch.is_autocast_enabled("cpu")) + finally: + torch.set_autocast_enabled("cpu", prev_enabled) + torch.set_autocast_dtype("cpu", prev_dtype) + torch.set_autocast_cache_enabled(prev_cache) + + def test__enter__exit_autocast_function_mode(self): + class FunctionCount(torch.overrides.TorchFunctionMode): + def __init__(self): + self.counts = defaultdict(int) + + def __torch_function__(self, func, types, args, kwargs=None): + self.counts[func] += 1 + return func(*args, **(kwargs or {})) + + def f(x, y): + m = torch.amp.autocast_mode._enter_autocast("cpu") + x = x @ y + torch.amp.autocast_mode._exit_autocast(m) + return x + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + x = torch.randn(3, 3, dtype=torch.float32) + y = torch.randn(3, 3, dtype=torch.float32) + with FunctionCount() as fc: + z = f(x, y) + self.assertEqual(fc.counts[torch.amp.autocast_mode._enter_autocast], 1) + self.assertEqual(fc.counts[torch.amp.autocast_mode._exit_autocast], 1) + with FunctionCount() as fc: + opt_z = opt_f(x, y) + self.assertEqual(fc.counts[torch.amp.autocast_mode._enter_autocast], 1) + self.assertEqual(fc.counts[torch.amp.autocast_mode._exit_autocast], 1) + self.assertEqual(z, opt_z) + self.assertEqual(z.dtype, opt_z.dtype) + self.assertFalse(torch.is_autocast_enabled("cpu")) + + def test__enter__exit_autocast_non_idempotent(self): + # Recompile trick doesn't work with dynamic shapes + if check_dynamic_shape_capture(): + return + + def f(x, y): + with torch.amp.autocast("cpu"): + x = x @ y + return x + + eager = EagerAndRecordGraphs() + opt_f = torch.compile(f, backend=eager, fullgraph=False) + x = torch.randn(3, 3, dtype=torch.float32) + y = torch.randn(3, 3, dtype=torch.float32) + out = f(x, y) + opt_out = opt_f(x, y) + self.assertEqual(out, opt_out) + self.assertEqual(out.dtype, opt_out.dtype) + self.assertFalse(torch.is_autocast_enabled("cpu")) + graph = eager.graphs[0] + actual = normalize_gm(graph.print_readable(False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"): + l_x_ = L_x_ + l_y_ = L_y_ + + _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', None, True, None) + + x: "bf16[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None + + _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None + return (x,) +""", # NOQA: B950 + ) + + # Recompiling will decompose the _enter_autocast and _exit_autocast calls to lower level autocast functions + eager = EagerAndRecordGraphs() + d = {} + exec(actual, globals(), d) + retraced = torch.compile(d["GraphModule"], backend=eager, fullgraph=True) + retraced_out = retraced()(x, y)[0] + self.assertEqual(out, retraced_out) + self.assertEqual(out.dtype, retraced_out.dtype) + self.assertFalse(torch.is_autocast_enabled("cpu")) + + graph = eager.graphs[0] + actual = normalize_gm(graph.print_readable(False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_L_x_: "f32[3, 3]", L_L_y_: "f32[3, 3]"): + l_l_x_ = L_L_x_ + l_l_y_ = L_L_y_ + + _is_autocast_available = torch._C._is_autocast_available('cpu'); _is_autocast_available = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', True); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + autocast_increment_nesting = torch.autocast_increment_nesting(); autocast_increment_nesting = None + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + x: "bf16[3, 3]" = l_l_x_ @ l_l_y_; l_l_x_ = l_l_y_ = None + autocast_decrement_nesting = torch.autocast_decrement_nesting(); autocast_decrement_nesting = None + + clear_autocast_cache = torch.clear_autocast_cache(); clear_autocast_cache = None + + set_autocast_enabled_1 = torch.set_autocast_enabled('cpu', False); set_autocast_enabled_1 = None + + set_autocast_dtype_1 = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype_1 = None + + set_autocast_cache_enabled_1 = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled_1 = None + return (x,) +""", # NOQA: B950 + ) + @parametrize( "Ctx", [CustomizedCtxManagerWithGraphBreak, customized_ctx_manager_with_graph_break], @@ -1448,6 +1832,36 @@ def forward(self, L_y_: "f32[]"): """, # NOQA: B950 ) + def test__saved_tensors_hooks_disable(self): + def fn(x): + y = x + 1 + torch._C._autograd._saved_tensors_hooks_disable("This is not supported") + y *= 2 + torch._C._autograd._saved_tensors_hooks_enable() + return y + + eager = EagerAndRecordGraphs() + torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(())) + graph = eager.graphs[0] + actual = normalize_gm(graph.print_readable(False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[]"): + l_x_ = L_x_ + + y: "f32[]" = l_x_ + 1; l_x_ = None + + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable = None + + y *= 2; y_1: "f32[]" = y; y = None + + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None + return (y_1,) +""", # NOQA: B950 + ) + def test_context_wrapping_grad_mode_decorator(self): ctx_wrappers = [(torch.enable_grad, True), (torch.no_grad, False)] for call in [True, False]: diff --git a/test/dynamo/test_guard_exclusion.py b/test/dynamo/test_guard_exclusion.py new file mode 100644 index 0000000000000..cfa84c7e302d2 --- /dev/null +++ b/test/dynamo/test_guard_exclusion.py @@ -0,0 +1,566 @@ +# Owner(s): ["module: dynamo"] +import torch +import torch._dynamo +import torch._dynamo.testing +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase + + +class GraphTracker: + """Backend that tracks which compiled graph (by compilation order) handles each call.""" + + def __init__(self): + self.graphs = [] + self.call_log = [] + + def __call__(self, gm, example_inputs): + graph_id = len(self.graphs) + self.graphs.append(gm) + + def wrapper(*args, **kwargs): + self.call_log.append(graph_id) + return gm.forward(*args, **kwargs) + + return wrapper + + @property + def frame_count(self): + return len(self.graphs) + + def reset(self): + self.graphs.clear() + self.call_log.clear() + + +@skipIfTorchDynamo("uses custom backend incompatible with PYTORCH_TEST_WITH_DYNAMO") +@torch._dynamo.config.patch(automatic_dynamic_exclusion_guard=True) +class TestGuardExclusion(TestCase): + def setUp(self): + super().setUp() + torch._dynamo.reset() + + def tearDown(self): + super().tearDown() + torch._dynamo.reset() + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_automatic_dynamic_exclusive_guard_basic(self): + """ + 1. [3, 4] -> Graph 0 (static) + 2. [5, 4] -> Graph 1 (dim 0 dynamic), exclusion rejects dim0==3 + 3. [7, 4] -> Graph 1 (reuse dynamic graph) + 4. [3, 4] -> Graph 0 (exclusion triggers, reverts to static) + """ + + def foo(x): + return x * 2 + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + x1 = torch.randn(3, 4) + result1 = opt(x1) + self.assertEqual(tracker.frame_count, 1) + self.assertEqual(tracker.call_log[-1], 0) + + x2 = torch.randn(5, 4) + result2 = opt(x2) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + # dynamic graph reuse + opt(torch.randn(7, 4)) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + # original shape reverts to Graph 0 + x3 = torch.randn(3, 4) + result3 = opt(x3) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 0) + + self.assertEqual(result1, x1 * 2) + self.assertEqual(result2, x2 * 2) + self.assertEqual(result3, x3 * 2) + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_accumulated_exclusion_does_not_shadow_intermediate_graph(self): + """ + Tensor accumulation: dims become dynamic one at a time. + 1. func(3, 4) -> Graph 0: static (3, 4) + 2. func(5, 4) -> Graph 1: (s0, 4), excluded dim0=3 + 3. func(3, 19) -> Graph 2: (s0, s1), excluded dim1=4 + (dim0's exclusion is cleared since no dim transitioned) + 4. func(5, 4) -> should use Graph 1, not Graph 2 + """ + + def foo(x): + return x * 2 + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + # Call 1: shape [3, 4] -> compiles Graph 0 (static) + opt(torch.randn(3, 4)) + self.assertEqual(tracker.frame_count, 1) + self.assertEqual(tracker.call_log[-1], 0) + + # Call 2: shape [5, 4] -> compiles Graph 1 (s0, 4) + opt(torch.randn(5, 4)) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + # Call 3: shape [3, 19] -> Graph 1 exclusion rejects (size(0)==3), + # Graph 0 rejects (19!=4), recompiles Graph 2 (s0, s1) + opt(torch.randn(3, 19)) + self.assertEqual(tracker.frame_count, 3) + self.assertEqual(tracker.call_log[-1], 2) + + # Call 4: shape [5, 4] -> should still use Graph 1 (s0, 4). + # Graph 2's exclusion is dim1=4, so (5, 4) is rejected and + # falls through to Graph 1. + opt(torch.randn(5, 4)) + + self.assertEqual( + tracker.call_log[-1], + 1, + "Input [5,4] should use Graph 1 (s0, 4), not Graph 2 (s0, s1). " + "Graph 2's exclusion must reject size(1)==4 independently, not " + "require size(0)==3 AND size(1)==4.", + ) + + # Call 5: shape [3, 4] -> should still use Graph 0 (static) + opt(torch.randn(3, 4)) + self.assertEqual( + tracker.call_log[-1], + 0, + "Input [3,4] should use Graph 0 (static)", + ) + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_4d_progressive_dynamism_cascading(self): + """ + 4D tensor where dims become dynamic one at a time across recompilations. + Each new graph is more general, and exclusion guards ensure inputs cascade + to the most specialized graph. + + Graph 0: (2, 3, 4, 5) static + Graph 1: (dyn, 3, 4, 5) excluded dim0=2 + Graph 2: (dyn, dyn, 4, 5) excluded dim1=3 + Graph 3: (dyn, dyn, dyn, 5) excluded dim2=4 + """ + + def foo(x): + return x.sum() + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + # Graph 0: (2, 3, 4, 5) static + opt(torch.randn(2, 3, 4, 5)) + self.assertEqual(tracker.frame_count, 1) + + # Graph 1: dim0 changes -> (dyn, 3, 4, 5) + opt(torch.randn(7, 3, 4, 5)) + self.assertEqual(tracker.frame_count, 2) + + # Graph 2: dim1 also changes -> (dyn, dyn, 4, 5) + # Input (7, 8, 4, 5): Graph 1 rejects dim1=8≠3, Graph 0 rejects dim0=7≠2 + opt(torch.randn(7, 8, 4, 5)) + self.assertEqual(tracker.frame_count, 3) + + # Graph 3: dim2 also changes -> (dyn, dyn, dyn, 5) + # Input (7, 8, 9, 5): Graph 2 rejects dim2=9≠4 + opt(torch.randn(7, 8, 9, 5)) + self.assertEqual(tracker.frame_count, 4) + + # Now verify cascading: each input routes to the most specialized graph. + # (2, 3, 4, 5) -> Graph 0 (static, most specialized) + opt(torch.randn(2, 3, 4, 5)) + self.assertEqual(tracker.call_log[-1], 0, "(2,3,4,5) -> Graph 0 (static)") + + # (7, 3, 4, 5) -> Graph 1 (dyn, 3, 4, 5) + opt(torch.randn(7, 3, 4, 5)) + self.assertEqual(tracker.call_log[-1], 1, "(7,3,4,5) -> Graph 1 (dyn,3,4,5)") + + # (7, 8, 4, 5) -> Graph 2 (dyn, dyn, 4, 5) + opt(torch.randn(7, 8, 4, 5)) + self.assertEqual(tracker.call_log[-1], 2, "(7,8,4,5) -> Graph 2 (dyn,dyn,4,5)") + + # (7, 8, 9, 5) -> Graph 3 (dyn, dyn, dyn, 5) + opt(torch.randn(7, 8, 9, 5)) + self.assertEqual( + tracker.call_log[-1], 3, "(7,8,9,5) -> Graph 3 (dyn,dyn,dyn,5)" + ) + + # (20, 30, 40, 5) -> Graph 3 (most general, no exclusion hit) + opt(torch.randn(20, 30, 40, 5)) + self.assertEqual( + tracker.call_log[-1], 3, "(20,30,40,5) -> Graph 3 (most general)" + ) + + self.assertEqual(tracker.frame_count, 4, "No additional recompilations") + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_5d_two_rounds_of_dynamism(self): + """ + 5D tensor with two rounds of automatic_dynamic. Verify inputs route + to the most specialized graph after each round. + + Graph 0: (2, 3, 4, 5, 6) static + Graph 1: (dyn, 3, 4, 5, 6) excluded dim0=2 + Graph 2: (dyn, 3, dyn, 5, 6) excluded dim2=4 + """ + + def foo(x): + return x * 2 + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + # Graph 0: static + opt(torch.randn(2, 3, 4, 5, 6)) + self.assertEqual(tracker.frame_count, 1) + + # Graph 1: dim0 becomes dynamic + opt(torch.randn(8, 3, 4, 5, 6)) + self.assertEqual(tracker.frame_count, 2) + + # Graph 2: dim2 also becomes dynamic + opt(torch.randn(8, 3, 10, 5, 6)) + self.assertEqual(tracker.frame_count, 3) + + # Verify routing: + # Original static shape -> Graph 0 + opt(torch.randn(2, 3, 4, 5, 6)) + self.assertEqual(tracker.call_log[-1], 0, "Original -> Graph 0") + + # dim0 differs, dim2 matches static -> Graph 1 + opt(torch.randn(9, 3, 4, 5, 6)) + self.assertEqual(tracker.call_log[-1], 1, "dim0 changed -> Graph 1") + + # dim0 differs, dim2 differs -> Graph 2 + opt(torch.randn(9, 3, 11, 5, 6)) + self.assertEqual(tracker.call_log[-1], 2, "dim0+dim2 changed -> Graph 2") + + # dim0 is original excluded value, dim2 differs -> Graph 2 should still + # accept because dim0's exclusion is None (already dynamic when snapshot taken) + opt(torch.randn(2, 3, 11, 5, 6)) + self.assertEqual( + tracker.call_log[-1], + 2, + "dim0=2 with dim2≠4 -> Graph 2 (dim0 exclusion is None)", + ) + + self.assertEqual(tracker.frame_count, 3, "No additional recompilations") + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_many_entries_wrong_graph_selection(self): + """ + Convoluted scenario: 4D tensor, three rounds of dynamism creating 4 graphs. + Without exclusion guards, the most general graph would shadow all others. + Test that each input gets the best (most specialized) match. + + Graph 0: (2, 3, 4, 5) static + Graph 1: (dyn, 3, 4, 5) excluded dim0=2 + After Graph 1, (2, 8, 4, 5) triggers dim1 dynamic: + Graph 2: (dyn, dyn, 4, 5) excluded dim1=3 + After Graph 2, (2, 8, 9, 5) triggers dim2 dynamic: + Graph 3: (dyn, dyn, dyn, 5) excluded dim2=4 + + Key: Graph 3 should NOT steal inputs that belong to Graph 0, 1, or 2. + """ + + def foo(x): + return x.relu() + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + # Build up 4 graphs progressively + opt(torch.randn(2, 3, 4, 5)) # Graph 0 + opt(torch.randn(7, 3, 4, 5)) # Graph 1: dim0 dynamic + opt(torch.randn(7, 8, 4, 5)) # Graph 2: dim1 also dynamic + opt(torch.randn(7, 8, 9, 5)) # Graph 3: dim2 also dynamic + self.assertEqual(tracker.frame_count, 4) + + # Now stress-test routing with various inputs: + test_cases = [ + # (shape, expected_graph, description) + ((2, 3, 4, 5), 0, "exact original -> static Graph 0"), + ((7, 3, 4, 5), 1, "dim0 differs -> Graph 1"), + ((99, 3, 4, 5), 1, "dim0 differs (large) -> Graph 1"), + ((7, 8, 4, 5), 2, "dim0+dim1 differ -> Graph 2"), + ((7, 99, 4, 5), 2, "dim0+dim1 differ (large) -> Graph 2"), + ((7, 8, 9, 5), 3, "dim0+dim1+dim2 differ -> Graph 3"), + ((99, 99, 99, 5), 3, "all non-static dims differ -> Graph 3"), + ] + + for shape, expected_graph, desc in test_cases: + opt(torch.randn(*shape)) + self.assertEqual(tracker.call_log[-1], expected_graph, desc) + + self.assertEqual(tracker.frame_count, 4, "No additional recompilations") + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_multi_dim_dynamic_and_semantics(self): + """ + When multiple dims become dynamic at once, AND semantics is critical. + Graph 0: (4, 3, 234, 5) static + Graph 1: (s0, 3, s2, s3) dynamic on dims 0,2,3. excluded=(4, _, 234, 5) + + OR semantics (wrong): rejects (4, 3, 100, 20) because dim0==4 matches. + AND semantics (correct): accepts (4, 3, 100, 20) because not ALL excluded + dims match (dim2=100≠234 and dim3=20≠5). + """ + + def foo(x): + return x * 2 + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + opt(torch.randn(4, 3, 234, 5)) # Graph 0: static + opt(torch.randn(10, 3, 100, 20)) # Graph 1: dims 0,2,3 dynamic + self.assertEqual(tracker.frame_count, 2) + + # Only the exact original shape should be excluded from Graph 1 + opt(torch.randn(4, 3, 234, 5)) + self.assertEqual(tracker.call_log[-1], 0, "Exact original -> Graph 0") + + # Partial matches should NOT be excluded (AND semantics) + opt(torch.randn(4, 3, 100, 20)) # dim0=4 matches, dims 2,3 don't + self.assertEqual(tracker.call_log[-1], 1, "dim0=4 partial match -> Graph 1") + + opt(torch.randn(10, 3, 234, 20)) # dim2=234 matches, dims 0,3 don't + self.assertEqual(tracker.call_log[-1], 1, "dim2=234 partial match -> Graph 1") + + opt(torch.randn(10, 3, 234, 5)) # dim2=234, dim3=5 match, dim0 doesn't + self.assertEqual(tracker.call_log[-1], 1, "dim2+dim3 partial match -> Graph 1") + + opt(torch.randn(4, 3, 234, 20)) # dim0=4, dim2=234 match, dim3 doesn't + self.assertEqual(tracker.call_log[-1], 1, "dim0+dim2 partial match -> Graph 1") + + # Totally new shape, no exclusion hit + opt(torch.randn(99, 3, 88, 77)) + self.assertEqual(tracker.call_log[-1], 1, "New shape -> Graph 1") + + self.assertEqual(tracker.frame_count, 2, "No additional recompilations") + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_integer_input_exclusion_basic(self): + """ + Integer inputs that become dynamic should also get exclusion guards. + 1. foo(x, 3) -> Graph 0: static n=3 + 2. foo(x, 5) -> Graph 1: dynamic n, excluded should reject n==3 + 3. foo(x, 3) -> should use Graph 0 (static), not Graph 1 + """ + + def foo(x, n): + return x * n + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + x = torch.randn(4) + + opt(x, 3) + self.assertEqual(tracker.frame_count, 1) + self.assertEqual(tracker.call_log[-1], 0) + + opt(x, 5) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + opt(x, 3) + self.assertEqual( + tracker.call_log[-1], + 0, + "Input n=3 should use Graph 0 (static), not Graph 1 (dynamic n).", + ) + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_integer_input_exclusion_accumulation(self): + """ + Same accumulation scenario as the tensor test but with integer inputs. + 1. foo(x, 3, 4) -> Graph 0: static (3, 4) + 2. foo(x, 5, 4) -> Graph 1: dynamic (s0, 4), exclusion rejects n0==3 + 3. foo(x, 3, 19) -> Graph 2: dynamic (s0, s1), exclusion should reject + n1==4 independently, not require n0==3 AND n1==4 + 4. foo(x, 5, 4) -> should use Graph 1, not Graph 2 + """ + + def foo(x, n, m): + return x * n + m + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + x = torch.randn(4) + + opt(x, 3, 4) + self.assertEqual(tracker.frame_count, 1) + self.assertEqual(tracker.call_log[-1], 0) + + opt(x, 5, 4) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + opt(x, 3, 19) + self.assertEqual(tracker.frame_count, 3) + self.assertEqual(tracker.call_log[-1], 2) + + opt(x, 5, 4) + self.assertEqual( + tracker.call_log[-1], + 1, + "Input (5, 4) should use Graph 1 (s0, 4), not Graph 2 (s0, s1).", + ) + + opt(x, 3, 4) + self.assertEqual( + tracker.call_log[-1], + 0, + "Input (3, 4) should use Graph 0 (static)", + ) + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_two_tensor_inputs_exclusion(self): + """ + Multi-tensor exclusion: all pairs are flattened and guarded with + not-all semantics. The guard rejects only when ALL excluded values + across ALL tensors match simultaneously. + + 1. foo(x=[3,4], y=[5,6]) -> Graph 0: all static + 2. foo(x=[3,10], y=[5,11]) -> Graph 1: x.dim1, y.dim1 dynamic + exclusion: Or(x.dim1!=4, y.dim1!=6) + 3. foo(x=[3,10], y=[5,6]) -> Graph 1 (not all match -> passes) + 4. foo(x=[3,4], y=[5,21]) -> Graph 1 (not all match -> passes) + 5. foo(x=[3,4], y=[5,6]) -> Graph 0 (all match -> rejected) + 6. foo(x=[3,10], y=[5,11]) -> Graph 1 + """ + + def foo(x, y): + return x.sum() + y.sum() + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + opt(torch.randn(3, 4), torch.randn(5, 6)) + self.assertEqual(tracker.frame_count, 1) + self.assertEqual(tracker.call_log[-1], 0) + + opt(torch.randn(3, 10), torch.randn(5, 11)) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + # Only y.dim1 matches excluded value; combined Or guard passes. + opt(torch.randn(3, 10), torch.randn(5, 6)) + self.assertEqual(tracker.frame_count, 2, "Should not recompile") + self.assertEqual(tracker.call_log[-1], 1) + + # Only x.dim1 matches excluded value; combined Or guard passes. + opt(torch.randn(3, 4), torch.randn(5, 21)) + self.assertEqual(tracker.frame_count, 2, "Should not recompile") + self.assertEqual(tracker.call_log[-1], 1) + + # Both match excluded values; Or guard fails -> falls to Graph 0. + opt(torch.randn(3, 4), torch.randn(5, 6)) + self.assertEqual(tracker.call_log[-1], 0) + + # Neither matches; Or guard passes -> Graph 1. + opt(torch.randn(3, 10), torch.randn(5, 11)) + self.assertEqual(tracker.call_log[-1], 1) + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_multi_tensor_and_scalar_accumulation(self): + """ + 3-dim tensors + 3 scalar inputs with cascading accumulation. + Each step transitions one input while the rest stay the same, + verifying that only the current transition's exclusion is emitted. + + Graph 0: all static + Graph 1: x.dim2, y.dim2 dynamic excl: Or(Ne(x.dim2,4), Ne(y.dim2,7)) + Graph 2: + n dynamic excl: Ne(n, 2) + Graph 3: + m dynamic excl: Ne(m, 3) + Graph 4: + k dynamic excl: Ne(k, 4) + """ + + def foo(x, y, n, m, k): + return x.sum() * n + y.sum() * m + k + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + # -- Compilation steps -- + + # Graph 0: all static + opt(torch.randn(2, 3, 4), torch.randn(5, 6, 7), 2, 3, 4) + self.assertEqual(tracker.frame_count, 1) + self.assertEqual(tracker.call_log[-1], 0) + + # Graph 1: tensor dim2 changes + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 2, 3, 4) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + # Graph 2: scalar n changes (tensor exclusions cleared) + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 8, 3, 4) + self.assertEqual(tracker.frame_count, 3) + self.assertEqual(tracker.call_log[-1], 2) + + # Graph 3: scalar m changes (n exclusion cleared) + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 8, 9, 4) + self.assertEqual(tracker.frame_count, 4) + self.assertEqual(tracker.call_log[-1], 3) + + # Graph 4: scalar k changes (m exclusion cleared) + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 8, 9, 15) + self.assertEqual(tracker.frame_count, 5) + self.assertEqual(tracker.call_log[-1], 4) + + # -- Verification: each input routes to the correct graph -- + + # k=4 triggers Graph 4 exclusion -> Graph 3 + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 8, 9, 4) + self.assertEqual(tracker.call_log[-1], 3, "k=4 should fall to Graph 3") + + # m=3 also triggers Graph 3 exclusion -> Graph 2 + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 8, 3, 4) + self.assertEqual(tracker.call_log[-1], 2, "m=3 should fall to Graph 2") + + # n=2 also triggers Graph 2 exclusion -> Graph 1 + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 2, 3, 4) + self.assertEqual(tracker.call_log[-1], 1, "n=2 should fall to Graph 1") + + # tensor dims match original -> Graph 1 exclusion triggers -> Graph 0 + opt(torch.randn(2, 3, 4), torch.randn(5, 6, 7), 2, 3, 4) + self.assertEqual(tracker.call_log[-1], 0, "Original sizes should use Graph 0") + + # mixed: new tensor dims + new scalars -> Graph 4 + opt(torch.randn(2, 3, 20), torch.randn(5, 6, 21), 50, 60, 70) + self.assertEqual(tracker.frame_count, 5, "Should not recompile") + self.assertEqual(tracker.call_log[-1], 4, "All-new values should use Graph 4") + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index c404432f430f7..c118bfc4ae87d 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -4363,6 +4363,74 @@ def fn(m, x): self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 4) + def test_id_of_container_as_dict_key(self): + MY_DICT = {"a": 1, "b": 2} + + def fn(x): + memo = {} + memo[id(MY_DICT)] = True + if id(MY_DICT) in memo: + return x + 1.0 + return x + 2.0 + + x = torch.randn(4) + correct = fn(x) + result = torch.compile(fn, fullgraph=True)(x) + self.assertEqual(result, correct) + + def test_id_of_list_as_dict_key(self): + MY_LIST = [1.0, 2.0] + + def fn(x): + memo = {} + memo[id(MY_LIST)] = True + if id(MY_LIST) in memo: + return x + 1.0 + return x + 2.0 + + x = torch.randn(4) + correct = fn(x) + result = torch.compile(fn, fullgraph=True)(x) + self.assertEqual(result, correct) + + def test_deepcopy_dict(self): + MY_DICT = {"a": 1, "b": 2.0, "c": None} + + def fn(x): + d = copy.deepcopy(MY_DICT) + d["b"] = 3.0 + return x + d["b"] + + x = torch.randn(4) + correct = fn(x) + result = torch.compile(fn, fullgraph=True)(x) + self.assertEqual(result, correct) + + def test_deepcopy_nested_dict(self): + NESTED = {"a": {"b": 1.0}, "c": [2.0, 3.0]} + + def fn(x): + d = copy.deepcopy(NESTED) + return x + d["a"]["b"] + d["c"][0] + + x = torch.randn(4) + correct = fn(x) + result = torch.compile(fn, fullgraph=True)(x) + self.assertEqual(result, correct) + + def test_deepcopy_list(self): + MY_LIST = [1.0, 2.0, 3.0] + + def fn(x): + lst = copy.deepcopy(MY_LIST) + lst[0] = 5.0 + return x + lst[0] + + x = torch.randn(4) + correct = fn(x) + result = torch.compile(fn, fullgraph=True)(x) + self.assertEqual(result, correct) + def test_global_state_guard_serialization(self): GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard guards = GlobalStateGuard() diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 679f160590485..1bb76c7333838 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -3377,6 +3377,35 @@ def forward(self, x): self.assertEqual(cnt.frame_count, 2) self.assertIsNotNone(model.linear.weight.grad) + @torch._dynamo.config.patch(skip_tensor_guards_with_matching_dict_tags=True) + @torch._dynamo.config.patch("use_recursive_dict_tags_for_guards", True) + def test_param_dtype_change_recompiles_with_recursive_dict_tags(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.scale = torch.nn.Parameter(torch.randn(4)) + + def forward(self, x): + return x * self.scale + + model = MyModule() + x = torch.randn(4) + + cnt = torch._dynamo.testing.CompileCounter() + compiled = torch.compile(model, backend=cnt, fullgraph=True) + + self.assertTrue(torch._dynamo.testing.same(model(x), compiled(x))) + self.assertEqual(cnt.frame_count, 1) + + model.to(dtype=torch.float64) + + recompiled = torch.compile(model, backend=cnt, fullgraph=True) + result = recompiled(x) + + self.assertEqual(result.dtype, torch.float64) + self.assertTrue(torch._dynamo.testing.same(model(x), result)) + self.assertEqual(cnt.frame_count, 2) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_param_requires_grad_submodule(self): class Inner(torch.nn.Module): diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index b3bb2bc633398..5648fe08dcc97 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -8041,6 +8041,24 @@ def f(x): result = f(torch.randn(5)) self.assertEqual(result, 5) + def test_one_hot_bounds_check_compiled(self): + # https://github.com/pytorch/pytorch/issues/144211 + # torch.compile(one_hot) should raise on out-of-bounds indices, + # not silently produce wrong results. + one_hot = torch.compile(torch.nn.functional.one_hot, fullgraph=True) + + a = torch.arange(0, 5) % 3 # [0, 1, 2, 0, 1] + with self.assertRaises(RuntimeError): + one_hot(a, 1) + + torch._dynamo.reset() + with self.assertRaises(RuntimeError): + one_hot(torch.tensor([-1, 0, 1]), 3) + + torch._dynamo.reset() + expected = torch.nn.functional.one_hot(a, 3) + self.assertEqual(one_hot(a, 3), expected) + class ReproTestsDevice(torch._dynamo.test_case.TestCase): def test_sub_alpha_scalar_repro(self, device): diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_copy b/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_copy b/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalDateClass.test_copy b/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalDateClass.test_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalDateFunction.test_copy b/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalDateFunction.test_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_copy b/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_copy b/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_copy b/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_copy b/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_copy b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_singleton_empty_frozenset b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_singleton_empty_frozenset deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_copy b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestUserObjects.test_dict_copy b/test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_collections-TestUserObjects.test_dict_copy rename to test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_copy b/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_copy rename to test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_copy b/test/dynamo_expected_failures/TestNN.test_to similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_copy rename to test/dynamo_expected_failures/TestNN.test_to diff --git a/test/dynamo_expected_failures/TestTorch.test_as_subclass b/test/dynamo_expected_failures/TestTorch.test_as_subclass deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestTorchFunctionOverride.test_Tensor___cuda_array_interface_____get__ b/test/dynamo_expected_failures/TestTorchFunctionOverride.test_Tensor___cuda_array_interface_____get__ deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect index 103cb407e078b..63d248c5a1420 100644 --- a/test/expect/HasDecompTest.test_aten_core_operators.expect +++ b/test/expect/HasDecompTest.test_aten_core_operators.expect @@ -351,6 +351,8 @@ aten::lt.Tensor aten::lt.Tensor_out aten::lt_.Scalar aten::lt_.Tensor +aten::max_pool2d_with_indices_backward +aten::max_pool2d_with_indices_backward.grad_input aten::maximum aten::maximum.out aten::mean diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 29997570218d2..f3d6813e797ee 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -954,8 +954,6 @@ aten::max_pool2d_backward aten::max_pool2d_backward.out aten::max_pool2d_with_indices aten::max_pool2d_with_indices.out -aten::max_pool2d_with_indices_backward -aten::max_pool2d_with_indices_backward.grad_input aten::max_pool3d_with_indices aten::max_pool3d_with_indices.out aten::max_pool3d_with_indices_backward diff --git a/test/export/test_draft_export.py b/test/export/test_draft_export.py index fefd35ad99ead..32aaa56d69baf 100644 --- a/test/export/test_draft_export.py +++ b/test/export/test_draft_export.py @@ -427,7 +427,7 @@ def forward(self, x, y): for node in _ep.graph.nodes: if bindings := node.meta.get("unbacked_bindings"): unbacked_binding_symbols.update(bindings.keys()) - self.assertEqual(len(unbacked_binding_symbols), 1) + self.assertEqual(len(unbacked_binding_symbols), 2) def test_offsets(self): class M(torch.nn.Module): diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 646ec14f9a8c5..e86c5f9d60759 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: export"] # flake8: noqa +import contextlib import copy import types import unittest @@ -14,7 +15,12 @@ dynamo_graph_capture_for_export, ) from torch._dynamo.test_case import run_tests, TestCase -from torch._functorch.aot_autograd import aot_export_module +from torch._export.utils import _compiling_state_context +from torch._functorch.aot_autograd import ( + aot_export_joint_with_descriptors, + aot_export_module, +) +from torch._guards import tracing as torch_tracing, TracingContext from torch.export import export from torch.export.experimental import _export_forward_backward, _sticky_export from torch.export.graph_signature import OutputKind @@ -1465,6 +1471,91 @@ def fn(b, h, q, k): # Same closure code + different captured value -> different spec self.assertNotEqual(spec_a, spec_b) + def test_aot_export_closure_buffer_mutation(self): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buf", torch.zeros(())) + + def forward(self, x): + self.buf.add_(x.sum()) + return x.sin() + + def make_closure(mod): + def fn(x): + mod._buffers["buf"].add_(x.sum()) + return x.sin() + + return fn + + class Wrapper(torch.nn.Module): + def __init__(self, fn, mod): + super().__init__() + self._parameters = mod._parameters + self._buffers = mod._buffers + self._modules = mod._modules + self._fn = fn + + def forward(self, x): + return self._fn(x) + + def run_export(capture_fn): + mod = Mod() + wrapped = Wrapper(make_closure(mod), mod) + x = torch.randn(4) + gm = capture_fn(wrapped)(x) + + with contextlib.ExitStack() as stack: + stack.enter_context( + torch_tracing( + gm.meta.get( + "tracing_context", TracingContext(gm.meta["fake_mode"]) + ) + ) + ) + stack.enter_context(_compiling_state_context()) + stack.enter_context(gm.meta["fake_mode"]) + + jd = aot_export_joint_with_descriptors( + stack, + gm, + args=(x,), + kwargs={}, + keep_inference_input_mutations=True, + disable_functionalization=True, + ) + return jd.graph_module, wrapped, x + + # Verify Dynamo-captured graph mutates the buffer via closure + mod = Mod() + wrapped = Wrapper(make_closure(mod), mod) + x = torch.randn(4) + gm = dynamo_graph_capture_for_export(wrapped)(x) + wrapped.buf.zero_() + gm(x) + self.assertEqual(wrapped.buf, x.sum()) + + # Verify joint graphs from both APIs match + joint_public, _, _ = run_export(dynamo_graph_capture_for_export) + joint_private, _, _ = run_export(_dynamo_graph_capture_for_export) + self.assertEqual( + str(joint_public.code).strip(), str(joint_private.code).strip() + ) + + # Verify numerical correctness of both joint graphs against eager + mod = Mod() + x = torch.randn(4) + eager_out = mod(x) + eager_buf = mod.buf.clone() + + for label, joint_gm in [("public", joint_public), ("private", joint_private)]: + buf_input = torch.zeros(()) + (exported_out,) = joint_gm(buf_input, x) + self.assertEqual(exported_out, eager_out, msg=f"{label}: output mismatch") + self.assertEqual( + buf_input, eager_buf, msg=f"{label}: buffer mutation mismatch" + ) + if __name__ == "__main__": run_tests() diff --git a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward2_cpu b/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward2_cpu deleted file mode 100644 index a2f73f3214dfe..0000000000000 --- a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward2_cpu +++ /dev/null @@ -1,5 +0,0 @@ -ERROR - File "/home/oulgen/py312/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py", line 1238, in add - out = lax.add(x, y) - ^^^^^^^^^^^^^ -TypeError: add got incompatible shapes for broadcasting: (56,), (40,). diff --git a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward3_cpu b/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward3_cpu deleted file mode 100644 index 4af5f1937326f..0000000000000 --- a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward3_cpu +++ /dev/null @@ -1,5 +0,0 @@ -ERROR - File "/home/oulgen/py312/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py", line 1238, in add - out = lax.add(x, y) - ^^^^^^^^^^^^^ -TypeError: add got incompatible shapes for broadcasting: (38,), (37,). diff --git a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward4_cpu b/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward4_cpu deleted file mode 100644 index 3308d592808d9..0000000000000 --- a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward4_cpu +++ /dev/null @@ -1,5 +0,0 @@ -ERROR - File "/home/oulgen/py312/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py", line 1238, in add - out = lax.add(x, y) - ^^^^^^^^^^^^^ -TypeError: add got incompatible shapes for broadcasting: (4,), (3,). diff --git a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward_cpu b/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward_cpu deleted file mode 100644 index 53aaef6eda9cd..0000000000000 --- a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward_cpu +++ /dev/null @@ -1,5 +0,0 @@ -ERROR - File "/home/oulgen/py312/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py", line 1238, in add - out = lax.add(x, y) - ^^^^^^^^^^^^^ -TypeError: add got incompatible shapes for broadcasting: (14,), (18,). diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 7eafb99e5caf9..ebfb52c57ff46 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1919,78 +1919,6 @@ def forward(self, x, y, lengths): self.check_model(model, example_inputs, dynamic_shapes=spec) torch.cuda.caching_allocator_enable(True) - @skipIfMPS - @config.patch({"triton.autotune_at_compile_time": None}) - @torch.fx.experimental._config.patch("backed_size_oblivious", True) - def test_slice_independent_backed_symints_no_unbacked(self): - # x[0:s1] where x.size(0) = s0-1 should produce Min(s1, s0-1), - # not an unbacked symint with a bad fallback value. - if self.device != GPU_TYPE: - raise unittest.SkipTest("requires triton") - - INNER_DIM = 4224 - - class Repro(torch.nn.Module): - def forward(self, x, y): - x_trimmed = x[:-1] - sliced = x_trimmed[: y.size(0)] - reshaped = sliced.reshape(-1, 128, 33) - expanded = reshaped.unsqueeze(3).expand(-1, 128, 33, 8) - shifts = torch.arange(0, 64, 8, device=x.device, dtype=torch.int64) - return (expanded >> shifts) & 255 - - torch.cuda.caching_allocator_enable(False) - model = Repro() - example_inputs = ( - torch.randint( - 0, 256, (200, INNER_DIM), device=self.device, dtype=torch.int64 - ), - torch.randn(50, 8, device=self.device), - ) - spec = { - "x": (Dim.DYNAMIC, Dim.STATIC), - "y": (Dim.DYNAMIC, Dim.STATIC), - } - self.check_model(model, example_inputs, dynamic_shapes=spec) - torch.cuda.caching_allocator_enable(True) - - @skipIfMPS - @config.patch({"triton.autotune_at_compile_time": None}) - @torch.fx.experimental._config.patch("backed_size_oblivious", True) - def test_slice_negative_index_backed_symints_no_unbacked(self): - # x[-s1:] where x.size(0) = s0-1 should produce Max(s0-1 - s1, 0), - # not an unbacked symint with a bad fallback value. - if self.device != GPU_TYPE: - raise unittest.SkipTest("requires triton") - - INNER_DIM = 4224 - - class Repro(torch.nn.Module): - def forward(self, x, y): - x_trimmed = x[:-1] - sliced = x_trimmed[-y.size(0) :] - reshaped = sliced.reshape(-1, 128, 33) - expanded = reshaped.unsqueeze(3).expand(-1, 128, 33, 8) - shifts = torch.arange(0, 64, 8, device=x.device, dtype=torch.int64) - return (expanded >> shifts) & 255 - - torch.cuda.caching_allocator_enable(False) - try: - model = Repro() - example_inputs = ( - torch.randint( - 0, 256, (200, INNER_DIM), device=self.device, dtype=torch.int64 - ), - torch.randn(50, 8, device=self.device), - ) - spec = { - "x": (Dim.DYNAMIC, Dim.STATIC), - "y": (Dim.DYNAMIC, Dim.STATIC), - } - self.check_model(model, example_inputs, dynamic_shapes=spec) - finally: - torch.cuda.caching_allocator_enable(True) - @config.patch({"triton.autotune_at_compile_time": None}) def test_stride_with_unbacked_expr(self): class Repro(torch.nn.Module): diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 887e8ab30ee38..6ae1dc0208480 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -308,8 +308,11 @@ def forward(self, x, y): if self.device == GPU_TYPE: kernel_bin = get_kernel_bin_format(self.device) self.assertTrue(not list(tmp_path.glob(f"*.{kernel_bin}"))) - # Check if .cubin.o files exist and use unique kernel names - self.assertTrue(list(tmp_path.glob(f"triton_*.{kernel_bin}.o"))) + # Check that cubin binaries are embedded as object files. + # Either individual per-kernel .o files or a single combined .o. + individual_objs = list(tmp_path.glob(f"triton_*.{kernel_bin}.o")) + combined_obj = list(tmp_path.glob("cubins_combined.o")) + self.assertTrue(individual_objs or combined_obj) # Check if the .so file was build successfully so_path = build_path / "libaoti_model.so" diff --git a/test/inductor/test_auto_chunker.py b/test/inductor/test_auto_chunker.py index 25957a01bcc6d..99d00d34ca8ae 100644 --- a/test/inductor/test_auto_chunker.py +++ b/test/inductor/test_auto_chunker.py @@ -271,6 +271,31 @@ def f(x, w): self.assertTrue(same(expect, actual, tol=1e-3)) self.assertEqual(metrics.num_auto_chunking, 1) + @config.patch("auto_chunker.output_size_threshold", 1024) + @config.patch("auto_chunker.num_chunk", 2) + def test_propagate_amax_unsqueeze(self): + M, K, N = 256, 4, 256 + x = torch.randn(M, K, device=GPU_TYPE, requires_grad=True) + w = torch.randn(K, N, device=GPU_TYPE, requires_grad=True) + + def f(x, w): + out = (x * 2) @ w + max_val = out.amax(dim=-1) + out = out - max_val.unsqueeze(-1) + out = torch.exp(out) + loss = out.sum() + loss.backward() + return loss + + expect = (f(x, w), x.grad, w.grad) + x.grad = None + w.grad = None + opt_f = torch.compile(f) + actual = (opt_f(x, w), x.grad, w.grad) + + self.assertTrue(same(expect, actual, tol=1e-3)) + self.assertEqual(metrics.num_auto_chunking, 1) + def test_set_num_chunk_with_compile_options(self): B = 32 T = 1024 diff --git a/test/inductor/test_codegen_triton.py b/test/inductor/test_codegen_triton.py index d4fb06706fcb1..8810b4d4ef166 100644 --- a/test/inductor/test_codegen_triton.py +++ b/test/inductor/test_codegen_triton.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] import contextlib +import unittest import sympy @@ -14,8 +15,14 @@ from torch._inductor.dtype_propagation import DtypePropagationOpsHandler, promote_types from torch._inductor.graph import GraphLowering from torch._inductor.test_case import TestCase as InductorTestCase +from torch._inductor.utils import run_and_get_code from torch._inductor.virtualized import V -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CPU, + HAS_GPU, + HAS_GPU_AND_TRITON, +) from torch.utils._sympy.functions import FloorDiv, TruncToFloat, TruncToInt from torch.utils._sympy.value_ranges import ValueRanges @@ -182,6 +189,19 @@ def test_materialize_trunc_to_float_expr_preserves_integer_subexpressions(self): sympy.Float(0.5) + TruncToFloat(s0), ) + @unittest.skipUnless(torch.version.hip is not None, "pointer_range_32 is HIP-only") + @unittest.skipUnless(HAS_GPU_AND_TRITON, "requires GPU and Triton") + def test_pointer_range_in_generated_code(self): + """Verify tt.pointer_range=32 appears in generated Triton code on HIP.""" + + def fn(x): + return x + 1 + + x = torch.randn(64, 64, device=GPU_TYPE, dtype=torch.bfloat16) + _, code = run_and_get_code(torch.compile(fn), x) + code_str = " ".join(code) + self.assertIn("tt.pointer_range", code_str) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_custom_op_autotune.py b/test/inductor/test_custom_op_autotune.py index b9e0888209d7a..b05515406dcbe 100644 --- a/test/inductor/test_custom_op_autotune.py +++ b/test/inductor/test_custom_op_autotune.py @@ -1429,6 +1429,7 @@ def test_model(x, weight): torch.testing.assert_close(result, test_x @ test_weight, rtol=1e-1, atol=1e-1) + @skipIfXpu def test_cudagraph_memory_cleanup(self): """Test that CUDA graph destruction automatically cleans up cuBLAS workspaces.""" if self.device != "cuda": @@ -1473,6 +1474,7 @@ def test_cudagraph_memory_cleanup(self): f"Memory leak detected: baseline={baseline_memory}, after_cleanup={memory_after_cleanup}", ) + @skipIfXpu def test_cudagraph_memory_cleanup_benchmarker(self): """Test that CUDA graph benchmarking cleans up memory without leaking.""" if self.device != "cuda": diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index ff6a40b51978e..1e0b5917e3d0a 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -614,6 +614,47 @@ def test_optimization_hint_with_expression_containing_precomputed_size(self): self.assertEqual(hint, 53) +class TestOptimizationHintZeroDivision(InductorTestCase): + """Test that optimization_hint handles ZeroDivisionError from ModularIndexing with zero-valued unbacked symbols.""" + + def test_modular_indexing_with_zero_divisor(self): + sizevars = SizeVarAllocator() + u0 = sizevars.shape_env.create_unbacked_symint().node.expr + u1 = sizevars.shape_env.create_unbacked_symint().node.expr + + # u0 + 1 ensures base != 0 after substitution; u1 is the divisor. + # With fallback=0: u0->0, u1->0, so (0+1) // 0 -> ZeroDivisionError. + # optimization_hint catches ZeroDivisionError and returns fallback. + expr = ModularIndexing(u0 + 1, u1, 4) + hint = sizevars.optimization_hint(expr, fallback=0) + self.assertEqual(hint, 0) + + def test_floor_div_with_zero_divisor(self): + """optimization_hint should not crash when FloorDiv has an unbacked + symbol as divisor that gets substituted with 0.""" + sizevars = SizeVarAllocator() + u0 = sizevars.shape_env.create_unbacked_symint().node.expr + u1 = sizevars.shape_env.create_unbacked_symint().node.expr + + # With fallback=0: u0->0, u1->0, FloorDiv(0+1, 0) -> ZeroDivisionError. + # optimization_hint catches ZeroDivisionError and returns fallback. + expr = FloorDiv(u0 + 1, u1) + hint = sizevars.optimization_hint(expr, fallback=0) + self.assertEqual(hint, 0) + + def test_modular_indexing_zero_divisor_nonzero_fallback(self): + """When fallback is nonzero, the hint should still not crash.""" + sizevars = SizeVarAllocator() + u0 = sizevars.shape_env.create_unbacked_symint().node.expr + u1 = sizevars.shape_env.create_unbacked_symint().node.expr + + # With fallback=8192: u0->8192, u1->8192 + # (8192+1) // 8192 = 1, 1 % 4 = 1 + expr = ModularIndexing(u0 + 1, u1, 4) + hint = sizevars.optimization_hint(expr, fallback=8192) + self.assertEqual(hint, 1) + + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index baf1c20c03661..b5b9e29157b94 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -16,6 +16,7 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, + skipIfXpu, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -669,6 +670,7 @@ def f(x): # the other is the piontwise kernel self.assertTrue(2, metrics.generated_kernel_count) + @patch("torch._inductor.scheduler.MixOrderReduction.is_split_reduction") @patch("torch._inductor.scheduler.MixOrderReduction.get_numel_rnumel") @patch("torch._inductor.scheduler.MixOrderReduction.get_common_read") @patch("torch._inductor.scheduler.MixOrderReduction.has_mix_reduction_orders") @@ -677,6 +679,7 @@ def test_mix_order_reduction_non_strict_mode( mock_has_mix_reduction_orders: mock.Mock, mock_get_common_read: mock.Mock, mock_get_numel_rnumel: mock.Mock, + mock_is_split_reduction: mock.Mock, ): """ This tests whether we can skip some non-critical checks @@ -709,6 +712,7 @@ def test_mix_order_reduction_non_strict_mode( from sympy import Integer mock_get_numel_rnumel.return_value = (Integer(1), Integer(1)) + mock_is_split_reduction.return_value = False mock_node_1.read_writes = mock.Mock() mock_node_1.read_writes.reads = [] @@ -730,7 +734,11 @@ def test_mix_order_reduction_non_strict_mode( self.assertFalse(MixOrderReduction.can_fuse(mock_node_1, mock_node_2)) with ( V.set_graph_handler(graph), - inductor_config.patch({"triton.mix_order_reduction_non_strict_mode": True}), + inductor_config.patch( + { + "triton.mix_order_reduction_non_strict_mode": True, + } + ), ): self.assertTrue(MixOrderReduction.can_fuse(mock_node_1, mock_node_2)) @@ -758,6 +766,7 @@ def f(x): compile_metrics = torch._dynamo.utils._compilation_metrics self.assertEqual(len(compile_metrics), 1, "Don't recompile") + @skipIfXpu(msg="https://github.com/intel/intel-xpu-backend-for-triton/issues/6398") def test_additive_rnumel(self): """ Fix https://github.com/pytorch/pytorch/issues/176375 @@ -937,6 +946,57 @@ def causal_mask(_b, _h, q, kv): loss.backward() self.assertTrue(metrics.codegen_mix_order_reduction > 1) + @inductor_config.patch("triton.mix_order_reduction", True) + @inductor_config.patch("triton.mix_order_reduction_non_strict_mode", True) + def test_dimension_refactoring_mismatch(self): + """ + This reproduces an issue where `simplify_and_reorder()` produces a different + dimension factorization than `_original_ranges` used during fusion decision. + For example, fusion might see (13, 8472) but codegen sees (26, 4236) after + the reduction split optimization adds a factor of 2 to the pointwise dimensions. + + We skip fusing split reductions for node1 in this case. + """ + + if not inductor_config.triton.mix_order_reduction: + self.skipTest("Mix order reduction not enabled") + + # Reproduce the RMSNorm backward pattern that triggered the bug. + # The key is: + # - Shape (M, N) = (13, 8472) where N=8472 is large enough to trigger split + # - RMSNorm backward creates reductions along both dimensions + # - The feature dimension reduction (8472) gets split with factor 2 + # - Mix order reduction tries to fuse these, but groups don't match after split + def f(x, w, eps): + orig_dtype = x.dtype + x = x.float() + # RMSNorm forward: y = x * rsqrt(mean(x^2) + eps) * w + rsqrt = torch.rsqrt((x * x).sum(dim=-1) / x.shape[-1] + eps) + y = (x * rsqrt[:, None] * w).to(dtype=orig_dtype) + return y + + def fwd_bwd(compiled_f): + x.grad = None + w.grad = None + out = compiled_f(x, w, eps) + out.backward(dy) + return x.grad, w.grad + + # Use the exact shape from the bug report: (13, 8472) + # 8472 = 2 * 4236, so split with factor 2 gives sub-reductions of 4236 + M, N = 13, 8472 + x = torch.randn(M, N, dtype=torch.float32, device=GPU_TYPE, requires_grad=True) + w = torch.randn(N, dtype=torch.float32, device=GPU_TYPE, requires_grad=True) + dy = torch.randn_like(x) + eps = 1e-5 + + opt_f = torch.compile(f) + + ref = fwd_bwd(f) + act = fwd_bwd(opt_f) + torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3) + self.assertGreaterEqual(metrics.codegen_mix_order_reduction, 0) + @inductor_config.patch( "triton.mix_order_reduction", not inductor_config.triton.mix_order_reduction diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index a08339594056b..dcd3ca6689cdf 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -1348,6 +1348,7 @@ def remap_fake_tensor(x): "max_autotune_gemm_backends": "TRITON", } ) + @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu") def test_original_aten_preserved_split_addmm(self): # addmm -> elementwise should be decomposed into mm -> add -> elementwise def fn(x, y, z): diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 0670be402e42c..7d87a91ba323f 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -3475,6 +3475,8 @@ def fn(a): check_lowp=False, ) + test_one_hot._expected_failure_halide = True + def test_div1(self): def fn(a, b): return ( @@ -3557,6 +3559,7 @@ def fn(a, b): ) @skip_if_triton_cpu # divide by zero; cannot xfail because it crashes process + @skipIfXpu(msg="https://github.com/intel/intel-xpu-backend-for-triton/issues/6401") def test_div7(self): def fn(a, b): return ( @@ -6301,6 +6304,94 @@ def fn(): actual = compiled() self.assertEqual(actual, expected) + def test_complex_uniform_constant_folding(self): + # Fix https://github.com/pytorch/pytorch/issues/174891 + # view.dtype with mismatched element sizes changes element count, + # so constant folding must not treat the result as uniform. + def fn(x): + mask = torch.ones(2, 1, dtype=torch.complex64, device=self.device) + return x + mask + + x = torch.full((2, 2), 1.0, dtype=torch.complex64, device=self.device) + expected = fn(x) + compiled = torch.compile(fn, backend="inductor") + actual = compiled(x) + self.assertEqual(actual, expected) + + def test_view_dtype_non_0d_larger_to_smaller_element_size(self): + # Non-0-d counterpart of test_view_dtype_0d_smaller_to_larger_element_size. + # element_size (8) > itemsize (4): complex64 -> float32. + import torch.fx as fx + from torch._inductor.fx_passes.joint_graph import UniformValueConstantFolder + + graph = fx.Graph() + + full_node = graph.call_function( + torch.ops.aten.full.default, + args=([2], 1 + 0j), + kwargs={ + "dtype": torch.complex64, + "layout": torch.strided, + "device": self.device, + "pin_memory": False, + }, + ) + full_node.meta["val"] = torch.full( + [2], 1 + 0j, dtype=torch.complex64, device=self.device + ) + + view_node = graph.call_function( + torch.ops.aten.view.dtype, args=(full_node, torch.float32) + ) + view_node.meta["val"] = torch.full( + [2], 1 + 0j, dtype=torch.complex64, device=self.device + ).view(torch.float32) + + graph.output(view_node) + gm = fx.GraphModule(torch.nn.Module(), graph) + + folder = UniformValueConstantFolder(gm) + folder.run() + + self.assertNotIn(view_node, folder.node_replacements) + + def test_view_dtype_non_0d_smaller_to_larger_element_size(self): + # Non-0-d counterpart of test_view_dtype_0d_smaller_to_larger_element_size. + # element_size (4) < itemsize (8): float32 -> complex64. + import torch.fx as fx + from torch._inductor.fx_passes.joint_graph import UniformValueConstantFolder + + graph = fx.Graph() + + full_node = graph.call_function( + torch.ops.aten.full.default, + args=([2], 1.0), + kwargs={ + "dtype": torch.float32, + "layout": torch.strided, + "device": self.device, + "pin_memory": False, + }, + ) + full_node.meta["val"] = torch.full( + [2], 1.0, dtype=torch.float32, device=self.device + ) + + view_node = graph.call_function( + torch.ops.aten.view.dtype, args=(full_node, torch.complex64) + ) + view_node.meta["val"] = torch.full( + [2], 1.0, dtype=torch.float32, device=self.device + ).view(torch.complex64) + + graph.output(view_node) + gm = fx.GraphModule(torch.nn.Module(), graph) + + folder = UniformValueConstantFolder(gm) + folder.run() + + self.assertNotIn(view_node, folder.node_replacements) + def test_uniform(self): def fn(x): return aten.uniform.default(x, 0, 1) @@ -10369,11 +10460,13 @@ def fn(a, b, c): indices, ], ) - assertGeneratedKernelCountEqual(self, 1) + # Note: Kernel count varies by backend (CUDA ~3, ROCm ~2) due to fusion. + # Correctness is validated by self.common() above. + self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0) @expectedFailureXPU def test_max_pool2d_with_indices_backward5(self): - # Window size is too big. Should fallback + # Large window size - decomposition handles via scatter_add def fn(a, b, c): return aten.max_pool2d_with_indices_backward( a, b, [13, 13], [1, 1], [2, 2], [1, 1], False, c @@ -10397,11 +10490,15 @@ def fn(a, b, c): indices, ], ) - assertGeneratedKernelCountEqual(self, 0) + # Note: Kernel count varies by backend (CUDA ~3, ROCm ~2) due to fusion. + # Correctness is validated by self.common() above. + # MPS: decomposition falls back to native kernel, so no inductor kernels generated + if self.device != "mps": + self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0) # From https://github.com/pytorch/pytorch/issues/93384 def test_max_pool2d_with_indices_backward6(self): - # dilation is not 1. Should fallback + # dilation != 1 - decomposition handles all dilation cases def fn(a, b, c): return aten.max_pool2d_with_indices_backward( a, b, [3, 2], [2, 1], [1, 1], [1, 2], False, c @@ -10425,7 +10522,11 @@ def fn(a, b, c): indices, ], ) - assertGeneratedKernelCountEqual(self, 0) + # Note: Kernel count varies by backend (CUDA ~3, ROCm ~2) due to fusion. + # Correctness is validated by self.common() above. + # MPS: decomposition falls back to native kernel, so no inductor kernels generated + if self.device != "mps": + self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0) def test_issue102546(self): def fn(x): diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 38ae6dace6b92..490c41fa443b3 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -216,12 +216,6 @@ def run(*ex, **kwargs): "test_linspace4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_logcumsumexp_dynamic_shapes": TestFailure(("cpu",)), "test_logcumsumexp_zero_dim_dynamic_shapes": TestFailure(("cpu",)), - "test_max_pool2d_with_indices_backward5_dynamic_shapes": TestFailure( - ("cpu", "cuda") - ), - "test_max_pool2d_with_indices_backward6_dynamic_shapes": TestFailure( - ("cpu", "cuda", "xpu") - ), "test_misaligned_address_issue1_dynamic_shapes": TestFailure(("cpu",)), "test_mm_views_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_new_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index f44f1e95033bc..7687c412c4974 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -30,6 +30,7 @@ MI350_ARCH, parametrize, serialTest, + skipIfRocm, skipIfRocmArch, TEST_CUDA_MEM_LEAK_CHECK, TEST_WITH_ASAN, @@ -644,6 +645,7 @@ def f(x, w): torch.compile(fullgraph=True)(f)(x, w).sum().backward() self.assertEqual(orig_w, w.grad) + @skipIfRocm # regression in ROCm 7.2, XBLOCK should remain 64 (got 256) @torch._dynamo.config.patch( capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True ) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index b6b6241e1f218..e729d9575c8a2 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -1319,7 +1319,7 @@ def fn(*args, **kwargs): # not exercised in test_ops_gradients atm. The problem is not # complex32 per-se (which is supported by data movement only ops) # but that when we do backwards we expect other ops like add to work - and dtype not in (torch.complex32, torch.bcomplex32) + and dtype != torch.complex32 ) samples = op.sample_inputs(device, dtype, requires_grad=requires_grad) extra = _inductor_extra_samples(op_name, device, dtype, requires_grad) diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 0b1246e9db1eb..ef2a2b2aff254 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -485,11 +485,11 @@ def join_threads(context: bool): self.assertEqual(len(observed_during_run), worker_threads) self.assertEqual(len(observed_during_run), len(set(observed_during_run))) - def payload(self, use_cuda=False): - x = torch.randn(10, 10) + def payload(self, use_cuda=False, tensor_size=10): + x = torch.randn(tensor_size, tensor_size) if use_cuda: x = x.cuda() - y = torch.randn(10, 10) + y = torch.randn(tensor_size, tensor_size) if use_cuda: y = y.cuda() z = torch.mm(x, y) @@ -1540,6 +1540,40 @@ def test_profiler_disable_fwd_bwd_link(self): finally: torch._C._profiler._set_fwd_bwd_enabled_val(True) + @unittest.skipIf(not kineto_available(), "Kineto is required") + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + def test_profiler_external_id_parity(self): + """Verify that FunctionEvent.external_id matches External id in Chrome trace JSON.""" + from collections import Counter + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("test_region"): + x = torch.randn(32, 32, device="cuda") + y = torch.mm(x, x) + z = y + x + z.cpu() + torch.cuda.synchronize() + + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) + with open(fname) as f: + j = json.load(f) + + json_name_ext = Counter( + (e["name"], e["args"]["External id"]) + for e in j["traceEvents"] + if e.get("args", {}).get("External id") is not None + ) + events_name_ext = Counter( + (ev.name, ev.external_id) for ev in prof.events() if ev.external_id != 0 + ) + + self.assertEqual( + events_name_ext, + json_name_ext, + "(name, external_id) pairs differ between events() and Chrome trace JSON", + ) + @unittest.skipIf(not kineto_available(), "Kineto is required") @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") def test_profiler_cuda_sync_events(self): @@ -2452,7 +2486,7 @@ def validate_json(prof, disable_external_correlation): disable_external_correlation=disable_external_correlation ), ) as prof: - self.payload(use_cuda=True) + self.payload(use_cuda=True, tensor_size=256) validate_json(prof, disable_external_correlation) @skipIfTorchDynamo("profiler gets ignored if dynamo activated") @@ -2816,6 +2850,27 @@ def test_activity_filter_empty_list(self): y = torch.mm(x, x) self.assertEqual(len(p.events()), 0) + @unittest.skipIf(not kineto_available(), "Kineto is required") + @unittest.skipIf(not TEST_CUDA, "CUDA is required") + def test_kineto_kernel_metadata_in_trace(self): + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + self.payload(use_cuda=True) + + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) + with open(fname) as f: + trace = json.load(f) + events = trace["traceEvents"] + kernel_events = [e for e in events if e.get("cat", "") == "kernel"] + self.assertGreater( + len(kernel_events), 0, "Error: No kernel events in trace" + ) + for ke in kernel_events: + args = ke.get("args", {}) + name = ke.get("name", "") + for key in ["device", "stream", "correlation", "grid", "block"]: + self.assertIn(key, args, f"kernel '{name}' missing '{key}'") + class SimpleNet(nn.Module): def __init__(self) -> None: diff --git a/test/test_autograd.py b/test/test_autograd.py index 2eb7c63e1b6f2..9b2c57ab1700c 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -5075,6 +5075,19 @@ def f(x): self.assertTrue(torch.autograd.is_view_replay_enabled()) self.assertFalse(torch.autograd.is_view_replay_enabled()) + prev = torch.autograd.is_view_replay_enabled() + ctx = torch.autograd._force_original_view_tracking(not prev) + # Construction eagerly sets state (function-form behavior). + self.assertEqual(torch.autograd.is_view_replay_enabled(), not prev) + with ctx: + self.assertEqual(torch.autograd.is_view_replay_enabled(), not prev) + out = f(x) + self.assertTrue( + ("ViewBackward" if not prev else "AsStridedBackward") + in str(out.grad_fn) + ) + self.assertEqual(torch.autograd.is_view_replay_enabled(), prev) + # Test as a function torch.autograd._force_original_view_tracking(False) out = f(x) @@ -5086,6 +5099,20 @@ def f(x): self.assertTrue("ViewBackward" in str(out.grad_fn)) self.assertTrue(torch.autograd.is_view_replay_enabled()) + prev = torch.autograd.is_view_replay_enabled() + + @torch.autograd._force_original_view_tracking(not prev) + def g(x): + return f(x) + + # __call__ undoes the __init__ mutation, so ambient state is restored. + self.assertEqual(torch.autograd.is_view_replay_enabled(), prev) + out = g(x) + self.assertTrue( + ("ViewBackward" if not prev else "AsStridedBackward") in str(out.grad_fn) + ) + self.assertEqual(torch.autograd.is_view_replay_enabled(), prev) + def test_unsafe_set_version_counter(self): x = torch.ones(2, requires_grad=True).clone() x.add_(1) diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 3f28e39fa6920..29dff031d6e5a 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -1413,14 +1413,14 @@ def test_pow(self, device, dtype): else: self._do_pow_for_exponents(m1, exponents, math.pow, None) will_raise_error = ( - dtype == torch.half and torch.device(device).type == "cpu" - ) or dtype == torch.bfloat16 + dtype is torch.half and torch.device(device).type == "cpu" + ) if will_raise_error: - # On CPU, Half/BFloat16 Tensor with complex exponents leads to - # computation dtype of ComplexHalf/BComplex32 for which this ops is not - # supported yet + # On CPU, + # Half Tensor with complex exponents leads to computation dtype + # of ComplexHalf for which this ops is not supported yet with self.assertRaisesRegex( - RuntimeError, "not implemented for '(ComplexHalf|BComplex32)'" + RuntimeError, "not implemented for 'ComplexHalf'" ): self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4) else: @@ -3627,7 +3627,7 @@ def _test_helper(a, b): ref = ref_func(a.cpu().float().numpy(), b.cpu().float().numpy()) v = our_func(a, b) self.assertEqual(ref, v.float(), atol=0.01, rtol=0.01) - elif dtype in (torch.complex32, torch.bcomplex32): + elif dtype == torch.complex32: ref = ref_func( a.cpu().to(torch.complex64).numpy(), b.cpu().to(torch.complex64).numpy(), diff --git a/test/test_cuda.py b/test/test_cuda.py index b69e3ba0867d1..eb84193ffcf65 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1704,6 +1704,13 @@ def test_norm_type_conversion(self): a = torch.ones(65536).cuda().half() self.assertEqual(a.norm(p=0, dtype=torch.float32), 65536) + @unittest.skipIf(not TEST_MEDIUM_TENSOR, "not enough memory") + @serialTest() + def test_cuda_opaque_type(self): + x = torch.ones(600_000_000, dtype=torch.int32, device="cuda") + y = torch.where(x > 0, x, x) + self.assertEqual(y, x) + def test_cuda_memory_leak_detection_propagates_errors(self): with self.assertRaisesRegex( RuntimeError, r"The size of tensor a \(3\) must match" @@ -5071,14 +5078,6 @@ def cb(device, alloc, device_alloc, device_free): def test_allocator_fuzz(self): # fuzz - if ( - torch.version.hip - and "expandable_segments:True" - in torch._C._accelerator_getAllocatorSettings() - ): - raise unittest.SkipTest( - "ROCm needs https://github.com/ROCm/rocm-systems/pull/3023" - ) state = random.getstate() random.seed(123) N = 10000 @@ -6718,6 +6717,7 @@ def test_graph_capture_pre_capture_stream_use(self): "graph_capture_record_stream_reuse:False" ) + @skipIfRocm(msg="expandable_segments mode is not supported on ROCm") @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Load_inline doesn't work in fbcode") def test_mempool_expandable(self): torch.cuda.empty_cache() diff --git a/test/test_cuda_expandable_segments.py b/test/test_cuda_expandable_segments.py index 25c2e9eaff2c5..262c53ab23ca4 100644 --- a/test/test_cuda_expandable_segments.py +++ b/test/test_cuda_expandable_segments.py @@ -12,7 +12,7 @@ import torch from torch.testing._internal.common_cuda import IS_JETSON, IS_WINDOWS -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent @@ -25,7 +25,12 @@ sys.path.remove(str(REPO_ROOT)) if __name__ == "__main__": - if torch.cuda.is_available() and not IS_JETSON and not IS_WINDOWS: + if ( + torch.cuda.is_available() + and not IS_JETSON + and not IS_WINDOWS + and not TEST_WITH_ROCM + ): get_disabled_tests(".") torch.cuda.memory._set_allocator_settings("expandable_segments:True") diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 2f0d8bff0c341..effdacc58852b 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -350,6 +350,7 @@ def test_device_inplace_copy(self): if y.copy_(x).device.type != "cuda": raise AssertionError("expected cuda device") + @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_fake_device(self): t = torch.ones(3) t = t.view(1, 3) diff --git a/test/test_nn.py b/test/test_nn.py index 75fc0831cfdd1..8090f5f751799 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2894,61 +2894,6 @@ def test_CTCLoss_zero_infinity(self): self.assertEqual(g1, g2, atol=1e-4, rtol=0) self.assertTrue((g1 == g1).all().item()) # check that we don't have NaN - @skipIfRocm - @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') - def test_CTCLoss_zero_infinity_cudnn(self): - # Example where model is confidently wrong (probability concentrated on wrong output), - # producing divergent loss - probs = torch.nn.functional.one_hot(torch.tensor([0], device='cuda'), num_classes=2).float() - log_probs = torch.log(probs).unsqueeze(1).requires_grad_() - targets = torch.tensor([1], device='cuda', dtype=torch.int32) - input_lengths = torch.tensor([1], device='cuda', dtype=torch.int32) - target_lengths = torch.tensor([1], device='cuda', dtype=torch.int32) - - self.assertTrue( - torch._use_cudnn_ctc_loss( - log_probs=log_probs, - targets=targets, - input_lengths=input_lengths, - target_lengths=target_lengths, - blank=0, - ) - ) - - loss_false = torch.nn.functional.ctc_loss( - log_probs, targets, input_lengths, target_lengths, reduction='sum', zero_infinity=False - ) - self.assertFalse(torch.isfinite(loss_false)) - - loss_true = torch.nn.functional.ctc_loss( - log_probs, targets, input_lengths, target_lengths, reduction='sum', zero_infinity=True - ) - self.assertTrue(torch.isfinite(loss_true)) - - @skipIfRocm - @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') - def test_CTCLoss_zero_infinity_cudnn_grad(self): - probs = torch.nn.functional.one_hot(torch.tensor([0], device='cuda'), num_classes=2).float() - log_probs = torch.log(probs).unsqueeze(1).requires_grad_() - targets = torch.tensor([1], device='cuda', dtype=torch.int32) - input_lengths = torch.tensor([1], device='cuda', dtype=torch.int32) - target_lengths = torch.tensor([1], device='cuda', dtype=torch.int32) - - # The example inputs above should produce a divergent gradient, but the deterministic implementation - # of the cuDNN CTC loss (which is the only implementation reachable from the public API) returns a - # finite gradient. For this reason, the private, non-deterministic implementation is used here. - loss_false, _ = torch._cudnn_ctc_loss( - log_probs, targets, input_lengths, target_lengths, blank=0, deterministic=False, zero_infinity=False - ) - grad_false, = torch.autograd.grad(loss_false, log_probs) - self.assertFalse(torch.isfinite(grad_false).all()) - - loss_true, _ = torch._cudnn_ctc_loss( - log_probs, targets, input_lengths, target_lengths, blank=0, deterministic=False, zero_infinity=True - ) - grad_true, = torch.autograd.grad(loss_true, log_probs) - self.assertTrue(torch.isfinite(grad_true).all()) - def test_RNN_cell_no_broadcasting(self): def test(cell_module, input, hx, input_size, hidden_size): cell = cell_module(input_size, hidden_size) diff --git a/test/test_ops.py b/test/test_ops.py index 5241f56d04d91..579b592e883c3 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -39,6 +39,7 @@ from torch.testing._internal.common_dtype import ( all_types_and_complex_and, floating_and_complex_types_and, + highest_precision_float, integral_types_and, ) from torch.testing._internal.common_methods_invocations import ( @@ -467,6 +468,7 @@ def test_reduction_ops_reduce(self, device, op): # resulting in possible equality check failures. # skip windows case on CPU due to https://github.com/pytorch/pytorch/issues/129947 # XPU test will be enabled step by step. Skip the tests temporarily. + # MPS does not support double precision, so single precision has to be used instead. @skipXPU @onlyNativeDeviceTypesAnd(["hpu"]) @suppress_warnings @@ -483,7 +485,7 @@ def test_numpy_ref(self, device, dtype, op): raise unittest.SkipTest("XXX: raises tensor-likes are not close.") # Sets the default dtype to NumPy's default dtype of double - with set_default_dtype(torch.double): + with set_default_dtype(highest_precision_float(device)): for sample_input in op.reference_inputs(device, dtype): self.compare_with_reference( op, op.ref, sample_input, exact_dtype=(dtype is not torch.long) @@ -601,11 +603,10 @@ def _ref_test_helper( if skip_bfloat and ( ( isinstance(sample.input, torch.Tensor) - and sample.input.dtype in {torch.bfloat16, torch.bcomplex32} + and sample.input.dtype == torch.bfloat16 ) or any( - isinstance(arg, torch.Tensor) - and arg.dtype in {torch.bfloat16, torch.bcomplex32} + isinstance(arg, torch.Tensor) and arg.dtype == torch.bfloat16 for arg in sample.args ) ): @@ -636,9 +637,13 @@ def _ref_test_helper( # precise dtypes -- they simply must be close precise_dtype = dtype if prims.utils.is_float_dtype(dtype): - precise_dtype = torch.double + precise_dtype = highest_precision_float(device) if prims.utils.is_complex_dtype(dtype): - precise_dtype = torch.cdouble + precise_dtype = ( + torch.complex32 + if torch.device(device).type == "mps" + else torch.cdouble + ) # Checks if the results are close try: @@ -1542,20 +1547,12 @@ def test_complex_half_reference_testing(self, device, dtype, op): unittest.skip("Does not support complex32") for sample in op.sample_inputs(device, dtype): - # MPS doesn't support float64 - if torch.float64 in ( - *sample.args, - *sample.kwargs.values(), - ) and not op.supports_dtype(torch.float64, device): - continue - actual = op(sample.input, *sample.args, **sample.kwargs) # sample.transform applies the lambda to torch.Tensor and torch.dtype. # However, we only want to apply it to Tensors with dtype `torch.complex32`.. transformed_sample = sample.transform( lambda x: x.to(torch.complex64) - if isinstance(x, torch.Tensor) - and x.dtype in (torch.complex32, torch.bcomplex32) + if isinstance(x, torch.Tensor) and x.dtype is torch.complex32 else x ) expected = op( diff --git a/test/test_schema_check.py b/test/test_schema_check.py index fdb02ebf9b49f..91d9a484d3c89 100644 --- a/test/test_schema_check.py +++ b/test/test_schema_check.py @@ -500,9 +500,9 @@ class TestSchemaCheckModeOpInfo(JitTestCase): @ops(op_db, dtypes=OpDTypes.supported) @slowTestIf(IS_WINDOWS) def test_schema_correctness(self, device, dtype, op): - # Currently torch.equal isn't supported with torch.complex32 or torch.bcomplex32. + # Currently torch.equal isn't supported with torch.complex32 # There's also errors with complex64 and complex128 - if (dtype in (torch.complex32, torch.bcomplex32)): + if (dtype == torch.complex32): return for sample in op.sample_inputs(device, dtype, requires_grad=False): with SchemaCheckMode(): diff --git a/test/test_sparse.py b/test/test_sparse.py index 83ac96723ee92..87a9bccd22dfc 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -5770,10 +5770,7 @@ def test_view_as_real(self, device, dtype): self.assertEqual(res.shape, xs.shape + (2,)) self.assertEqual(res._values()[..., 0], xs._values().real) self.assertEqual(res._values()[..., 1], xs._values().imag) - if not ( - dtype in (torch.complex32, torch.bcomplex32) - and torch.device(device).type == "cpu" - ): + if not (dtype is torch.complex32 and torch.device(device).type == "cpu"): # ComplexHalf to_dense() is not supported on CPU. self.assertEqual(res.to_dense(), torch.view_as_real(xs.to_dense())) self.assertEqual(torch.view_as_complex(torch.view_as_real(xs)), xs) diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 5dc9de8609575..fa865499cc8ea 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -2257,8 +2257,8 @@ def test_mul_scalar(self, layout, device, dtype, enable_hybrid): for sparse in self.generate_simple_inputs( layout, device=device, dtype=dtype, index_dtype=torch.int32, enable_hybrid=enable_hybrid): for scalar_dtype in all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half): - # ComplexHalf/BComplex32 is experimental - if dtype in (torch.half, torch.bfloat16) and scalar_dtype.is_complex: + # ComplexHalf is experimental + if dtype is torch.half and scalar_dtype.is_complex: continue scalar_t = torch.tensor(2, dtype=scalar_dtype) diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index 696c1571e50a8..4a7dac9f0bdd8 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -221,11 +221,8 @@ def test_fft_round_trip(self, device, dtype): } y = backward(forward(x, **kwargs), **kwargs) - if ( - (x.dtype is torch.half and y.dtype is torch.complex32) or - (x.dtype is torch.bfloat16 and y.dtype is torch.bcomplex32) - ): - # Since type promotion currently doesn't work with [b]complex32 + if x.dtype is torch.half and y.dtype is torch.complex32: + # Since type promotion currently doesn't work with complex32 # manually promote `x` to complex32 x = x.to(torch.complex32) # For real input, ifft(fft(x)) will convert to complex diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index 347371b7f9171..abb32d525bf71 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -140,10 +140,8 @@ def complex_scalar_tensor_test(s, t): expected_dtype = s.dtype else: expected_dtype = float_to_corresponding_complex_type_map[torch.get_default_dtype()] - # Note (bcomplex32): Remove the guard against dtype once bcomplex32 is more widely supported. - if expected_dtype != torch.bcomplex32: - self.assertEqual((s * t).dtype, expected_dtype) - self.assertEqual((t * s).dtype, expected_dtype) + self.assertEqual((s * t).dtype, expected_dtype) + self.assertEqual((t * s).dtype, expected_dtype) self.assertEqual(torch.result_type(s, t), expected_dtype) self.assertEqual(torch.result_type(t, s), expected_dtype) @@ -262,11 +260,9 @@ def test_bfloat16(self, device): self.assertEqual((bf + scalar).dtype, torch.bfloat16) self.assertEqual(scalar + bf, bf + scalar) - # Note (bcomplex32): Add scalar complex testing back once bcomplex32 - # is more widely requested. - # for scalar in (complex(1, 1), complex(-2, 0), complex(0, -3)): - # self.assertEqual((bf + scalar).dtype, torch.cfloat) - # self.assertEqual(bf + scalar, scalar + bf) + for scalar in (complex(1, 1), complex(-2, 0), complex(0, -3)): + self.assertEqual((bf + scalar).dtype, torch.cfloat) + self.assertEqual(bf + scalar, scalar + bf) # with tensor for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): @@ -493,9 +489,6 @@ def _get_dtype(x): dtype_b = _get_dtype(b) try: result = a + b - except NotImplementedError: - # Note (bcomplex32): Remove this branch when bcomplex32 ops are more widely implemented. - pass except RuntimeError: with self.assertRaises(RuntimeError): torch.promote_types(dtype_a, dtype_b) diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 7db412e68854a..6fa4bd4705701 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -241,7 +241,7 @@ def _helper_reference_numerics( torch_kwargs, numpy_kwargs = op.sample_kwargs(t.device, dtype, t) if dtype is torch.bfloat16: a = t.cpu().to(torch.float32).numpy() - elif dtype in (torch.complex32, torch.bcomplex32): + elif dtype is torch.complex32: a = t.cpu().to(torch.complex64).numpy() else: a = t.cpu().numpy() @@ -1911,6 +1911,33 @@ def test_mvlgamma_integer_promotion(self, device, dtype): self.assertTrue(result.dtype.is_floating_point) self.assertTrue(torch.all(torch.isfinite(result))) + @onlyCUDA + @dtypes(torch.float32, torch.float16, torch.bfloat16) + def test_fp8_e4m3fn_conversion_subnormals(self, device, dtype): + # Regression test for ptxas codegen bug on sm_100 where FADD in the + # subnormal conversion path gets wrong source register for odd elements + # in the 8-wide unrolled vectorized_elementwise_kernel. + # e4m3fn subnormals: |x| < 2^-6 + torch.manual_seed(0) + N = 2**20 + x = (torch.randn(N, dtype=dtype, device=device) * 1e-3).clamp(-448, 448) + y = x.to(torch.float8_e4m3fn) + ref = x.cpu().float().to(torch.float8_e4m3fn) + self.assertEqual(y.cpu().view(torch.uint8), ref.view(torch.uint8)) + + @onlyCUDA + @dtypes(torch.float32, torch.float16, torch.bfloat16) + def test_fp8_e5m2_conversion_subnormals(self, device, dtype): + # Same regression test for e5m2. + # e5m2 subnormals: |x| < 2^-14 + torch.manual_seed(0) + N = 2**20 + x = (torch.randn(N, dtype=dtype, device=device) * 1e-4).clamp(-57344, 57344) + y = x.to(torch.float8_e5m2) + ref = x.cpu().float().to(torch.float8_e5m2) + self.assertEqual(y.cpu().view(torch.uint8), ref.view(torch.uint8)) + + instantiate_device_type_tests(TestUnaryUfuncs, globals()) if __name__ == "__main__": diff --git a/test/test_varlen_attention.py b/test/test_varlen_attention.py index 97ea94ba26a34..0cfd74f9ff1e1 100644 --- a/test/test_varlen_attention.py +++ b/test/test_varlen_attention.py @@ -1,7 +1,7 @@ # Owner(s): ["module: sdpa"] import unittest from collections import namedtuple -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext import torch import torch.nn as nn @@ -11,10 +11,19 @@ restore_flash_attention_impl, ) from torch.nn.attention.varlen import varlen_attn, varlen_attn_out -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION +from torch.testing._internal.common_cuda import ( + IS_SM90, + PLATFORM_SUPPORTS_FLASH_ATTENTION, + SM100OrLater, +) from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_nn import NNTestCase -from torch.testing._internal.common_utils import parametrize, run_tests, TEST_WITH_ROCM +from torch.testing._internal.common_utils import ( + decorateIf, + parametrize, + run_tests, + TEST_WITH_ROCM, +) from torch.utils._python_dispatch import TorchDispatchMode @@ -32,6 +41,22 @@ def use_fa3(): restore_flash_attention_impl() +@contextmanager +def use_fa4(): + try: + activate_flash_attention_impl("FA4") + except (ModuleNotFoundError, RuntimeError) as err: + raise unittest.SkipTest("FA4 backend not available") from err + try: + yield + finally: + restore_flash_attention_impl() + + +def _use_backend(backend): + return {"fa2": nullcontext, "fa3": use_fa3, "fa4": use_fa4}[backend]() + + VarlenShape = namedtuple( "VarlenShape", ["batch_size", "max_seq_len", "embed_dim", "num_heads"] ) @@ -215,7 +240,11 @@ class TestVarlenAttention(NNTestCase): not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" ) @parametrize("dtype", [torch.bfloat16, torch.float16]) - def test_basic_functionality(self, device, dtype): + @parametrize( + "backend", + ["fa2"] + (["fa3"] if IS_SM90 else []) + (["fa4"] if SM100OrLater else []), + ) + def test_basic_functionality(self, device, dtype, backend): torch.manual_seed(42) shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16) @@ -236,39 +265,47 @@ def test_basic_functionality(self, device, dtype): [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32 ) - output = attention_block.forward_varlen(x_packed, cu_seq, shape.max_seq_len) + with _use_backend(backend): + output = attention_block.forward_varlen(x_packed, cu_seq, shape.max_seq_len) - self.assertEqual(output.shape, (total_tokens, shape.embed_dim)) - self.assertEqual(output.device, torch.device(device)) - self.assertEqual(output.dtype, dtype) + self.assertEqual(output.shape, (total_tokens, shape.embed_dim)) + self.assertEqual(output.device, torch.device(device)) + self.assertEqual(output.dtype, dtype) - # varlen_attn_out should produce the same result and write into the buffer - with torch.no_grad(): - q, k, v = attention_block.get_varlen_qkv(x_packed) - expected = varlen_attn( - q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len - ) - out_buf = torch.empty_like(expected) - actual = varlen_attn_out( - out_buf, q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len - ) - self.assertEqual(actual.data_ptr(), out_buf.data_ptr()) - self.assertEqual(out_buf, expected) - - varlen_grad_out = torch.ones_like(output) - - varlen_grad = torch.autograd.grad( - outputs=output, - inputs=x_packed, - grad_outputs=varlen_grad_out, - retain_graph=True, - create_graph=False, - allow_unused=False, - )[0] + # varlen_attn_out should produce the same result and write into the buffer + with torch.no_grad(): + q, k, v = attention_block.get_varlen_qkv(x_packed) + expected = varlen_attn( + q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len + ) + out_buf = torch.empty_like(expected) + actual = varlen_attn_out( + out_buf, + q, + k, + v, + cu_seq, + cu_seq, + shape.max_seq_len, + shape.max_seq_len, + ) + self.assertEqual(actual.data_ptr(), out_buf.data_ptr()) + self.assertEqual(out_buf, expected) + + varlen_grad_out = torch.ones_like(output) + + varlen_grad = torch.autograd.grad( + outputs=output, + inputs=x_packed, + grad_outputs=varlen_grad_out, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] - self.assertIsNotNone(varlen_grad) - self.assertEqual(varlen_grad.shape, x_packed.shape) - self.assertEqual(varlen_grad.dtype, x_packed.dtype) + self.assertIsNotNone(varlen_grad) + self.assertEqual(varlen_grad.shape, x_packed.shape) + self.assertEqual(varlen_grad.dtype, x_packed.dtype) @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" @@ -430,7 +467,11 @@ def run_varlen_out(q, k, v, cu_seq, max_len): (1025, 1025), ], ) - def test_varlen_vs_sdpa(self, device, dtype, scale, window_size): + @parametrize( + "backend", + ["fa2"] + (["fa3"] if IS_SM90 else []) + (["fa4"] if SM100OrLater else []), + ) + def test_varlen_vs_sdpa(self, device, dtype, scale, window_size, backend): torch.manual_seed(42) shape = VarlenShape( @@ -459,13 +500,14 @@ def test_varlen_vs_sdpa(self, device, dtype, scale, window_size): golden_attention_block.out_proj.weight.to(dtype) ) - varlen_output = attention_block.forward_varlen( - x_packed, - cu_seq, - max_len, - scale=scale, - window_size=window_size, - ) + with _use_backend(backend): + varlen_output = attention_block.forward_varlen( + x_packed, + cu_seq, + max_len, + scale=scale, + window_size=window_size, + ) sdpa_output = attention_block.forward_sdpa( x_padded, seq_lengths, @@ -499,29 +541,30 @@ def test_varlen_vs_sdpa(self, device, dtype, scale, window_size): start_idx = end_idx - grad_out = torch.randn_like(varlen_output) - sdpa_grad_out = torch.zeros_like(sdpa_output) - golden_sdpa_grad_out = torch.zeros( - shape.batch_size, - max_len, - shape.embed_dim, - device=device, - dtype=torch.float32, - ) - start_idx = 0 - for i, seq_len in enumerate(seq_lengths): - end_idx = start_idx + seq_len - sdpa_grad_out[i, :seq_len] = grad_out[start_idx:end_idx] - golden_sdpa_grad_out[i, :seq_len] = grad_out[start_idx:end_idx].to( - torch.float32 + with _use_backend(backend): + grad_out = torch.randn_like(varlen_output) + sdpa_grad_out = torch.zeros_like(sdpa_output) + golden_sdpa_grad_out = torch.zeros( + shape.batch_size, + max_len, + shape.embed_dim, + device=device, + dtype=torch.float32, ) - start_idx = end_idx + start_idx = 0 + for i, seq_len in enumerate(seq_lengths): + end_idx = start_idx + seq_len + sdpa_grad_out[i, :seq_len] = grad_out[start_idx:end_idx] + golden_sdpa_grad_out[i, :seq_len] = grad_out[start_idx:end_idx].to( + torch.float32 + ) + start_idx = end_idx - varlen_grad = torch.autograd.grad( - outputs=varlen_output, - inputs=x_packed, - grad_outputs=grad_out, - )[0] + varlen_grad = torch.autograd.grad( + outputs=varlen_output, + inputs=x_packed, + grad_outputs=grad_out, + )[0] sdpa_grad = torch.autograd.grad( outputs=sdpa_output, @@ -562,6 +605,7 @@ def test_varlen_vs_sdpa(self, device, dtype, scale, window_size): @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" ) + @unittest.skipIf(not IS_SM90, "FA3 requires compute capability 9.0") @parametrize("dtype", [torch.bfloat16, torch.float16]) @parametrize("num_splits", [1, None]) @parametrize( @@ -618,6 +662,7 @@ def test_batch_invariance( all_k = torch.cat([target_k, extra_k], dim=0) all_v = torch.cat([target_v, extra_v], dim=0) + # fa4 is batch invariant (num_splits=1) by default with use_fa3(), torch.no_grad(): solo_output = varlen_attn( target_q, @@ -683,6 +728,11 @@ def test_batch_invariance( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" ) @unittest.skipIf(TEST_WITH_ROCM, "ROCm does not support seqused_k") + @decorateIf( + unittest.expectedFailure, + lambda params: params["backend"] != "fa2" + and any(kv_len < 128 for kv_len in params["actual_kv_lens"]), + ) @parametrize("dtype", [torch.bfloat16, torch.float16]) @parametrize( "actual_kv_lens", @@ -694,7 +744,11 @@ def test_batch_invariance( [127, 63, 33, 17], ], ) - def test_seqused_k_kv_cache(self, device, dtype, actual_kv_lens): + @parametrize( + "backend", + ["fa2"] + (["fa3"] if IS_SM90 else []) + (["fa4"] if SM100OrLater else []), + ) + def test_seqused_k_kv_cache(self, device, dtype, actual_kv_lens, backend): torch.manual_seed(42) batch_size = 4 @@ -748,7 +802,7 @@ def test_seqused_k_kv_cache(self, device, dtype, actual_kv_lens): ) seqused_k = torch.tensor(actual_kv_lens, device=device, dtype=torch.int32) - with torch.no_grad(): + with _use_backend(backend), torch.no_grad(): output_cached = varlen_attn( q_packed, k_cache_packed, @@ -763,7 +817,7 @@ def test_seqused_k_kv_cache(self, device, dtype, actual_kv_lens): k_real_packed, cu_seq_k_real, max_k_real = pack_sequences(k_seqs, device) v_real_packed = torch.cat(v_seqs, dim=0) - with torch.no_grad(): + with _use_backend(backend), torch.no_grad(): output_reference = varlen_attn( q_packed, k_real_packed, @@ -778,7 +832,7 @@ def test_seqused_k_kv_cache(self, device, dtype, actual_kv_lens): self.assertEqual(output_cached, output_reference) # varlen_attn_out with seqused_k should match - with torch.no_grad(): + with _use_backend(backend), torch.no_grad(): out_buf = torch.empty_like(q_packed) output_out = varlen_attn_out( out_buf, @@ -811,9 +865,16 @@ def test_seqused_k_kv_cache(self, device, dtype, actual_kv_lens): [127, 63, 33, 17], ], ) + @parametrize( + "backend", + ["fa2"] + (["fa3"] if IS_SM90 else []) + (["fa4"] if SM100OrLater else []), + ) def test_block_table_kv_cache( - self, device, dtype, page_size, compile, actual_kv_lens + self, device, dtype, page_size, compile, actual_kv_lens, backend ): + if backend == "fa2" and page_size % 256 != 0: + self.skipTest("FA2 paged KV requires page_size divisible by 256") + torch.manual_seed(42) batch_size = 4 @@ -860,7 +921,8 @@ def test_block_table_kv_cache( attn_fn = torch.compile(varlen_attn, fullgraph=True) if compile else varlen_attn - with torch.no_grad(): + # Reference: no block_table + with _use_backend(backend), torch.no_grad(): output_reference = varlen_attn( q_packed, k_real_packed, @@ -879,41 +941,26 @@ def test_block_table_kv_cache( dtype=torch.int32, ) - # FA2 path: paged KV with block_table (page_size % 256 == 0) - if page_size % 256 == 0: - with torch.no_grad(): - output_fa2 = attn_fn( - q_packed, - k_pages, - v_pages, - cu_seq_q, - cu_seq_k, - max_q, - cache_size, - seqused_k=seqused_k, - block_table=block_table, - ) - - self.assertEqual(output_fa2, output_reference) + # FA2 requires cu_seq_k for paged KV; FA3/FA4 pass None + cu_seq_k_paged = cu_seq_k if backend == "fa2" else None - # FA3 path: paged KV with block_table - with use_fa3(), torch.no_grad(): - output_fa3 = attn_fn( + with _use_backend(backend), torch.no_grad(): + output_paged = attn_fn( q_packed, k_pages, v_pages, cu_seq_q, - None, + cu_seq_k_paged, max_q, cache_size, seqused_k=seqused_k, block_table=block_table, ) - self.assertEqual(output_fa3, output_reference) + self.assertEqual(output_paged, output_reference) # varlen_attn_out with paged KV cache should match - with use_fa3(), torch.no_grad(): + with _use_backend(backend), torch.no_grad(): out_buf = torch.empty_like(q_packed) output_out = varlen_attn_out( out_buf, @@ -921,21 +968,21 @@ def test_block_table_kv_cache( k_pages, v_pages, cu_seq_q, - None, + cu_seq_k_paged, max_q, cache_size, seqused_k=seqused_k, block_table=block_table, ) self.assertEqual(output_out.data_ptr(), out_buf.data_ptr()) - self.assertEqual(out_buf, output_fa3) + self.assertEqual(out_buf, output_paged) - # compile the lower level aten op, will cause graph break - if compile: + # compile the lower level aten op (FA3 only, will cause graph break) + if compile and backend != "fa2": compiled_aten_op = torch.compile( torch.ops.aten._flash_attention_forward_no_dropout_inplace ) - with use_fa3(), torch.no_grad(): + with _use_backend(backend), torch.no_grad(): out_buf = torch.empty_like(q_packed) compiled_aten_op( out_buf, diff --git a/test/test_view_ops.py b/test/test_view_ops.py index e7f28615051e7..58a397fde5964 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -403,12 +403,7 @@ def fn(contiguous_input=True, dim0=0, dim1=1): @dtypesIfMPS(torch.cfloat, torch.chalf) def test_view_as_real(self, device, dtype): def fn(contiguous_input=True): - # `torch.bcomplex32` doesn't have randn yet - real_dt = torch.empty((0,), dtype=dtype, device=device).real.dtype - r = torch.randn(3, 4, dtype=real_dt, device=device) - c = torch.randn(3, 4, dtype=real_dt, device=device) - t = torch.complex(r, c) - self.assertEqual(t.dtype, dtype) + t = torch.randn(3, 4, dtype=dtype, device=device) input = self._do_transpose(t, contiguous_input) res = torch.view_as_real(input) self.assertEqual(res[:, :, 0], input.real) diff --git a/test/test_xpu.py b/test/test_xpu.py index 28ad38a9ee19b..c9055e5e0f572 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -23,6 +23,7 @@ from torch.testing import make_tensor from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast from torch.testing._internal.common_device_type import ( + dtypes, instantiate_device_type_tests, OpDTypes, ops, @@ -2772,6 +2773,78 @@ def convert_boolean_tensors(x): self.assertEqual(expect, actual) + @dtypes(torch.float16, torch.bfloat16, torch.float32) + def test_fused_rms_norm(self, device, dtype): + # Verify _fused_rms_norm is dispatched to XPU kernel (not fallback) + has_xpu_kernel = torch._C._dispatch_has_kernel_for_dispatch_key( + "aten::_fused_rms_norm", + torch._C._dispatch_key_name(torch._C.DispatchKey.XPU), + ) + self.assertTrue(has_xpu_kernel, "_fused_rms_norm XPU kernel is not registered") + has_xpu_kernel = torch._C._dispatch_has_kernel_for_dispatch_key( + "aten::_fused_rms_norm_backward", + torch._C._dispatch_key_name(torch._C.DispatchKey.XPU), + ) + self.assertTrue( + has_xpu_kernel, "_fused_rms_norm_backward XPU kernel is not registered" + ) + + shapes = [ + (2, 16), # small 2D + (4, 8, 32), # 3D + (1, 1, 64), # degenerate batch + (8, 128), # typical sequence hidden + (2, 16, 512), # typical LLM hidden dim + (4, 32, 1024), # larger hidden dim + (1, 1, 4096), # LLM-scale hidden + (3, 7, 17), # non-power-of-2 + ] + eps = 1e-5 + atol_fwd = 1e-1 if dtype in [torch.float16, torch.bfloat16] else 1e-5 + atol_bwd = 1e-1 if dtype in [torch.float16, torch.bfloat16] else 1e-5 + + for shape in shapes: + normalized_shape = list(shape[-1:]) + x = torch.randn(*shape, dtype=dtype, device=device, requires_grad=True) + w = torch.randn( + *normalized_shape, dtype=dtype, device=device, requires_grad=True + ) + grad_out = torch.randn(*shape, dtype=dtype, device=device) + x_cpu = x.detach().cpu().requires_grad_(True) + w_cpu = w.detach().cpu().requires_grad_(True) + grad_out_cpu = grad_out.detach().cpu() + + # Forward + y, _ = torch.ops.aten._fused_rms_norm(x, normalized_shape, w, eps) + y_cpu, _ = torch.ops.aten._fused_rms_norm( + x_cpu, normalized_shape, w_cpu, eps + ) + self.assertEqual( + y, + y_cpu, + atol=atol_fwd, + rtol=0, + msg=f"forward shape={shape}, dtype={dtype}", + ) + + # Backward + y.backward(grad_out) + y_cpu.backward(grad_out_cpu) + self.assertEqual( + x.grad.cpu(), + x_cpu.grad, + atol=atol_bwd, + rtol=0, + msg=f"x_grad shape={shape}, dtype={dtype}", + ) + self.assertEqual( + w.grad.cpu(), + w_cpu.grad, + atol=atol_bwd, + rtol=0, + msg=f"w_grad shape={shape}, dtype={dtype}", + ) + instantiate_device_type_tests(TestXpuOps, globals(), only_for="xpu", allow_xpu=True) diff --git a/third_party/xpu.txt b/third_party/xpu.txt index b8389f90d5a4f..e7218f179f730 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -d7494bcbcc7b710aed75ae84c8c53390eaa408cc +724411d3a8076dca7e62a2de30c3a5c84d6c3271 diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index e976bb07ee103..44f0360c1fb33 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -2002,7 +2002,6 @@ def replace_special_case(hint: str) -> str: "cfloat", "complex128", "cdouble", - "bcomplex32", "quint8", "qint8", "qint32", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 0bce936f884d3..d925d7e616d2f 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2610,7 +2610,6 @@ def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ... def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ... def _accelerator_resetPeakStats(device_index: _int) -> None: ... def _accelerator_getMemoryInfo(device_index: _int) -> tuple[_int, _int]: ... -def _accelerator_getAllocatorSettings() -> str: ... def _accelerator_setAllocatorSettings(env: str) -> None: ... class _acceleratorGraph: diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 814ab243bde2a..93b89744b1a17 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -62,6 +62,7 @@ class _KinetoEvent: def duration_ns(self) -> int: ... def is_async(self) -> bool: ... def linked_correlation_id(self) -> int: ... + def external_id(self) -> int: ... def shapes(self) -> list[list[int]]: ... def dtypes(self) -> list[str]: ... def concrete_inputs(self) -> list[Any]: ... diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index a718e0511d47e..3b31f64d23d19 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -5472,6 +5472,125 @@ def resize_as(self, other, memory_format=None): return aten.resize(self, other.shape, memory_format=memory_format) +@register_decomposition(aten.max_pool2d_with_indices_backward) +def max_pool2d_with_indices_backward( + grad_output: Tensor, + self: Tensor, + kernel_size, + stride, + padding, + dilation, + ceil_mode: bool, + indices: Tensor, +): + """ + Decomposition of max_pool2d_with_indices_backward using scatter_add. + + This replaces the native implementation with a high-level decomposition + that uses scatter_add for gradient accumulation. The scatter-based approach + provides automatic optimization opportunities for Inductor and handles all + pooling configurations without requiring specialized fallback paths. + + Algorithm: + For each output gradient position, use the corresponding index from the + forward pass to scatter the gradient to the input position. When multiple + output positions select the same input position as max, scatter_add + automatically accumulates their gradients. + + Complexity: O(B * C * H_out * W_out) + Independent of kernel size, unlike traditional O(B * C * H_in * W_in * K²) + approaches that iterate over input positions and kernel windows. + + Known Limitations: + - FP16/BF16: Uses FP32 accumulation internally to preserve precision when + many gradients accumulate to the same position (overlapping pooling windows). + This adds slight overhead but ensures numerical stability. + - Deterministic mode: Falls back to native implementation to ensure + consistent results across runs + + Args: + grad_output: Gradient w.r.t. pooling output [B, C, H_out, W_out] + self: Original input tensor (for shape) [B, C, H_in, W_in] + kernel_size: Pooling kernel size + stride: Pooling stride + padding: Pooling padding + dilation: Pooling dilation + ceil_mode: Whether to use ceil for output size calculation + indices: Indices from forward pass (per-channel linear positions) + + Returns: + Gradient w.r.t. input [B, C, H_in, W_in] + """ + # Use native kernel in deterministic mode + if torch.are_deterministic_algorithms_enabled(): + return NotImplemented + + # MPS: Use native kernel. scatter_add has correctness issues on macOS 14 + # (#163327) and numerical differences on macOS 15+. + if grad_output.device.type == "mps": + return NotImplemented + + # Get spatial dimensions + in_height = self.size(-2) + in_width = self.size(-1) + out_height = grad_output.size(-2) + out_width = grad_output.size(-1) + + # Handle both 3D (C, H, W) and 4D (B, C, H, W) cases by treating 3D as 4D + is_batched = self.dim() == 4 + if not is_batched: + self = self.unsqueeze(0) + grad_output = grad_output.unsqueeze(0) + indices = indices.unsqueeze(0) + + batch_size = self.size(0) + channels = self.size(1) + + # For FP16/BF16, use FP32 accumulation to avoid precision loss + # This is critical when many gradients accumulate to the same position + # (overlapping pooling windows with large kernels or stride < kernel_size) + use_fp32_accum = grad_output.dtype in (torch.float16, torch.bfloat16) + accum_dtype = torch.float32 if use_fp32_accum else grad_output.dtype + + # Create grad_input with correct accumulation dtype from the start + grad_input_flat = torch.zeros( + batch_size * channels, + in_height * in_width, + dtype=accum_dtype, + device=grad_output.device, + ) + + # Reshape grad_output and indices to (B*C, H_out*W_out) + grad_output_flat = grad_output.reshape( + batch_size * channels, out_height * out_width + ) + indices_flat = indices.reshape(batch_size * channels, out_height * out_width) + + # Convert grad_output to accumulation dtype if needed + if use_fp32_accum: + grad_output_flat = grad_output_flat.to(torch.float32) + + # Scatter gradients to input positions + grad_input_flat = grad_input_flat.scatter_add(1, indices_flat, grad_output_flat) + + # Reshape back to original input shape + grad_input = grad_input_flat.reshape(batch_size, channels, in_height, in_width) + + # Convert back to original dtype if we used FP32 accumulation + if use_fp32_accum: + grad_input = grad_input.to(grad_output.dtype) + + # Preserve memory format from input (channels_last vs channels_first) + memory_format = utils.suggest_memory_format(self) + grad_input = grad_input.contiguous(memory_format=memory_format) + + # Remove batch dimension for 3D case + if not is_batched: + grad_input = grad_input.squeeze(0) + + return grad_input + + register_inplace(aten.addbmm_, aten.addbmm) register_inplace(aten.addmm_, aten.addmm) register_inplace(aten.addmv_, aten.addmv) @@ -5498,3 +5617,22 @@ def resize_as(self, other, memory_format=None): register_inplace(aten.scatter_add_, aten.scatter_add) register_inplace(aten.scatter_reduce_, aten.scatter_reduce) register_inplace(aten.silu_, aten.silu) + + +@aten.one_hot.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def one_hot(self: Tensor, num_classes: int = -1) -> Tensor: + if num_classes == -1: + num_classes = int(self.max().item()) + 1 + # _assert_async is side-effectful and won't be DCE'd + aten._assert_async.msg( + torch.all(self >= 0), + "one_hot: Class values must be non-negative.", + ) + aten._assert_async.msg( + torch.all(self < num_classes), + "one_hot: Class values must be smaller than num_classes.", + ) + return ( + self.unsqueeze(-1) + == torch.arange(num_classes, dtype=self.dtype, device=self.device) + ).to(torch.int64) diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index f258338c3baf9..ddbfa358035f5 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -24,7 +24,7 @@ from .code_context import code_context from .convert_frame import replay from .decorators import ( - allow_in_graph, + allow_in_graph, # pyrefly: ignore [deprecated] assume_constant_result, disable, disable_nested_graph_breaks, diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index e6642584d7ccd..4f367c8625128 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -166,6 +166,7 @@ def invoke_subgraph_inner_compiler( from torch._higher_order_ops.invoke_subgraph import invoke_subgraph_infer @disable + # pyrefly: ignore [deprecated] @torch._dynamo.allow_in_graph def invoke_subgraph_wrapper_unboxed(*operands: Any) -> Any: return invoke_subgraph_infer(subgraph, *operands) diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index ce8eac2b2a527..ef612ea446253 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -223,6 +223,7 @@ def add( assert not hasattr(self, name) result = Op(name, fn, is_custom_function) if is_traceable: + # pyrefly: ignore [deprecated] setattr(self, name, torch._dynamo.allow_in_graph(result)) else: # C++ autograd function was not marked as traceable diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index d38105046f50a..12dec59650d2a 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -180,6 +180,26 @@ # Valid options: "dynamic", "unbacked" automatic_dynamic_shapes_mark_as: Literal["dynamic", "unbacked"] = "dynamic" +# When True, adds exclusion guards for tensor dims and scalars that transition +# from static to dynamic via automatic_dynamic_shapes. +# +# Invariant: when enabled, automatic_dynamic recompilation preserves graph +# selection — inputs that matched a previous static cache entry will continue +# to use that entry, not be intercepted by a newer dynamic entry. This holds +# as long as recompilations are caused solely by the same variable being +# observed with different static values (progressive dynamism). A recompilation +# triggered by a different reason (e.g., a guard failure unrelated to shape +# transitions) will clear the exclusion state for that entry. +# +# Mechanism: the exclusion guard rejects inputs matching the prior static +# graph's sizes, so those inputs fall through to the more specialized static +# graph instead of being captured by the newer dynamic graph. +# +# Scope: applies only to graph-input-level dimension and scalar transitions. +# Does NOT handle data-dependent branching (if x.size(0) > k), graph breaks, +# or other recompilation triggers where no dimension actually transitions. +automatic_dynamic_exclusion_guard = False + # log graph in/out metadata # This is only turned on for export today since we # know we are tracing a flat callable. later, this diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index e5a8740ef7882..bd3df2c7a9224 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -193,6 +193,7 @@ def allow_in_graph(fn): # type: ignore[no-untyped-def] WARNING: this API can be a footgun, please read the documentation carefully. """ if isinstance(fn, (list, tuple)): + # pyrefly: ignore [deprecated] return [allow_in_graph(x) for x in fn] assert callable(fn), "allow_in_graph expects a callable" if trace_rules.lookup_callable(fn) != variables.TorchInGraphFunctionVariable: @@ -1407,15 +1408,21 @@ def _allow_in_graph_einops() -> None: # helpers. Backport the try/except TypeError fallback from einops 0.7.0+ # so allow_in_graph works during fake tensor validation. _patch_einops_symint_compat(einops.einops) # type: ignore[attr-defined] + # pyrefly: ignore [deprecated] allow_in_graph(einops.rearrange) + # pyrefly: ignore [deprecated] allow_in_graph(einops.reduce) if hasattr(einops, "repeat"): + # pyrefly: ignore [deprecated] allow_in_graph(einops.repeat) # available since einops 0.2.0 if hasattr(einops, "einsum"): + # pyrefly: ignore [deprecated] allow_in_graph(einops.einsum) # available since einops 0.5.0 if hasattr(einops, "pack"): + # pyrefly: ignore [deprecated] allow_in_graph(einops.pack) # available since einops 0.6.0 if hasattr(einops, "unpack"): + # pyrefly: ignore [deprecated] allow_in_graph(einops.unpack) # available since einops 0.6.0 diff --git a/torch/_dynamo/functional_export.py b/torch/_dynamo/functional_export.py index 348bb22573a86..9b4f293b560be 100644 --- a/torch/_dynamo/functional_export.py +++ b/torch/_dynamo/functional_export.py @@ -605,13 +605,11 @@ def create_fx_graph_from_captured_output( graph_module = backend_input.graph_module if isinstance(root, torch.nn.Module): - graph_module._parameters = root._parameters.copy() - graph_module._buffers = root._buffers.copy() + graph_module._parameters = root._parameters + graph_module._buffers = root._buffers assert all(not hasattr(graph_module, m) for m in root._modules) graph_module._modules.update(root._modules) - graph_module._non_persistent_buffers_set = ( - root._non_persistent_buffers_set.copy() - ) + graph_module._non_persistent_buffers_set = root._non_persistent_buffers_set if sys.version_info >= (3, 14): import annotationlib # added in 3.14 diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 2649e3e10dd5e..92e5b7a368f4b 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -4248,6 +4248,26 @@ ] } ], + "GB4501": [ + { + "Gb_type": "Reconstruction of FakeIdVariable", + "Context": "str(self.value)", + "Explanation": "A fake id produced by id() on a compile-time container cannot be reconstructed across a graph break.", + "Hints": [ + "Avoid using id() on containers in code that may graph-break." + ] + } + ], + "GB4198": [ + { + "Gb_type": "Attempted to copy.deepcopy a tensor", + "Context": "copy.deepcopy({self})", + "Explanation": "Dynamo does not support copy.deepcopy() on tensors.", + "Hints": [ + "Avoid calling copy.deepcopy() on tensors inside compiled regions." + ] + } + ], "GB0344": [ { "Gb_type": "wrap_with_autocast: expected constant arg", diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py index ee27fbd91947c..80bb6ddfda1cc 100644 --- a/torch/_dynamo/pgo.py +++ b/torch/_dynamo/pgo.py @@ -259,6 +259,10 @@ class FrameStateSizeEntry: stride: AutoDynamic | AutoUnset | tuple[int | AutoDynamic | InferStride, ...] = ( dataclasses.field(default=auto_unset) ) + excluded_sizes: tuple[int | None, ...] | None = dataclasses.field( + default=None, compare=False + ) + excluded_scalar: int | None = dataclasses.field(default=None, compare=False) def render(self) -> str: # Special cases @@ -384,8 +388,33 @@ def _merge_atom_tup( return tuple(cls._merge_atom(x, y) for x, y in zip(xs, ys)) def __ior__(self, other: Self) -> Self: + # Record current static sizes before merge. For dims that become + # dynamic, the exclusion guard will reject these values so inputs + # fall through to the earlier, more specialized cache entry. + # Already-dynamic dims become None and are ignored by the guard. + # When no dim transitions, clear stale excluded_sizes so later + # compilations don't inherit exclusions from earlier transitions. + new_size = self._merge_atom_tup(self.size, other.size) + if isinstance(self.size, tuple): + if new_size != self.size: + self.excluded_sizes = tuple( + s if type(s) is int else None for s in self.size + ) + elif self.excluded_sizes is not None: + self.excluded_sizes = None + self.size = new_size + # Same idea for scalars: record the static value about to become dynamic. + # Re-derive like excluded_sizes: only set when transitioning from a + # concrete int, clear when already dynamic. + if ( + type(self.scalar) is int + and type(other.scalar) is int + and self.scalar != other.scalar + ): + self.excluded_scalar = self.scalar + elif self.scalar is auto_dynamic and self.excluded_scalar is not None: + self.excluded_scalar = None self.scalar = self._merge_atom(self.scalar, other.scalar) - self.size = self._merge_atom_tup(self.size, other.size) self.stride = self._merge_atom_tup(self.stride, other.stride) return self diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index fb8ecf3633364..1165a4d1f2354 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -11,7 +11,8 @@ from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Sequence from itertools import repeat as _repeat from operator import eq, ne -from typing import Any, TYPE_CHECKING, TypeVar +from typing import Any, TYPE_CHECKING, TypeGuard, TypeVar +from typing_extensions import TypeIs import torch @@ -95,11 +96,11 @@ def radians(x: float) -> float: return math.pi / 180.0 * x -def impl_IS_MAPPING(a: object) -> bool: +def impl_IS_MAPPING(a: object) -> TypeIs[Mapping[Any, Any]]: return isinstance(a, Mapping) -def impl_MATCH_SEQUENCE(a: object) -> bool: +def impl_MATCH_SEQUENCE(a: object) -> TypeGuard[Sequence[Any]]: return isinstance(a, Sequence) and not isinstance(a, (str, bytes, bytearray)) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 0753cc2a91057..c269bf26c41e5 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -235,18 +235,13 @@ "torch.Tensor#split": TorchInGraphFunctionVariable, "torch.cuda.set_device": SkipFunctionVariable, "torch.cuda.current_device": TorchInGraphFunctionVariable, - "torch._C.autocast_decrement_nesting": SkipFunctionVariable, - "torch._C.autocast_increment_nesting": SkipFunctionVariable, "torch.autograd.grad": TorchInGraphFunctionVariable, "torch.autograd.backward": SkipFunctionVariable, - "torch._C.clear_autocast_cache": SkipFunctionVariable, "torch.distributions.constraints.is_dependent": SkipFunctionVariable, "torch.jit.isinstance": SkipFunctionVariable, "torch._C.set_anomaly_enabled": SkipFunctionVariable, - "torch._C.set_autocast_cache_enabled": SkipFunctionVariable, "torch._C.set_autocast_cpu_dtype": SkipFunctionVariable, "torch._C.set_autocast_cpu_enabled": SkipFunctionVariable, - "torch._C.set_autocast_enabled": SkipFunctionVariable, "torch._C.set_autocast_gpu_dtype": SkipFunctionVariable, "torch._C.set_autocast_ipu_dtype": SkipFunctionVariable, "torch._C.set_autocast_ipu_enabled": SkipFunctionVariable, @@ -469,7 +464,6 @@ "torch._C._accelerator_getAccelerator", "torch._C._accelerator_getDeviceIndex", "torch._C._accelerator_getStream", - "torch._C._accelerator_getAllocatorSettings", "torch._C._accelerator_setAllocatorSettings", "torch._C._accelerator_setStream", "torch._C._accelerator_synchronizeDevice", @@ -478,6 +472,8 @@ "torch._C._add_docstr", "torch._C._are_functorch_transforms_active", "torch._C._autograd_init", + "torch._C._autograd._saved_tensors_hooks_disable", + "torch._C._autograd._saved_tensors_hooks_enable", "torch._C._awaitable_nowait", "torch._C._awaitable_wait", "torch._C._awaitable", @@ -744,6 +740,7 @@ "torch._C._initExtension", "torch._C._is_alias_of", "torch._C._is_any_autocast_enabled", + "torch._C._is_autocast_available", "torch._C._is_cached_tensor", "torch._C._is_flash_attention_available", "torch._C._is_fwd_grad_enabled", @@ -1398,6 +1395,9 @@ "torch._C._xpu_resetPeakMemoryStats", "torch._C._xpu_setStream", "torch._C._xpu_synchronize", + "torch._C.autocast_decrement_nesting", + "torch._C.autocast_increment_nesting", + "torch._C.clear_autocast_cache", "torch._C.fork", "torch._C.get_autocast_cpu_dtype", "torch._C.get_autocast_dtype", @@ -1424,6 +1424,8 @@ "torch._C.parse_schema", "torch._C.parse_type_comment", "torch._C.read_vitals", + "torch._C.set_autocast_cache_enabled", + "torch._C.set_autocast_enabled", "torch._C.set_vital", "torch._C.unify_type_list", "torch._C.vitals_enabled", @@ -2222,6 +2224,7 @@ "torch.select", "torch.selu_", "torch.selu", + "torch.set_autocast_dtype", "torch.sgn", "torch.sigmoid_", "torch.sigmoid", @@ -2445,8 +2448,6 @@ "torch.accelerator.set_stream", "torch.accelerator.synchronize", "torch.align_tensors", - "torch.amp.autocast_mode._enter_autocast", - "torch.amp.autocast_mode._exit_autocast", "torch.amp.autocast_mode.autocast_decorator", "torch.amp.autocast_mode.custom_bwd", "torch.amp.autocast_mode.custom_fwd", @@ -3178,7 +3179,6 @@ def _builtin_function_ids() -> dict[int, str]: rv.update( { id(cast): "typing.cast", - id(copy.deepcopy): "copy.deepcopy", } ) return rv @@ -3324,6 +3324,13 @@ def is_numpy_type_info(obj: Any) -> bool: linecache, ) +# Builtin modules that should be skipped at the top-level (PEP 523 frame +# evaluation) but inlined when called from code dynamo is already tracing. +# For example, copy.deepcopy should be inlined when the user calls it inside +# a compiled function, but copy module frames should be skipped when they +# appear as top-level frames (e.g. called internally by autograd). +BUILTIN_INLINE_WHEN_CALLED: set[str] = set() + # third party libraries skiplist is defined by str, because users may not use these libraries. # we should use lazy import & skip in the future. THIRDPARTY_SKIPLIST = ( @@ -3636,6 +3643,8 @@ def get_mod_skiplist() -> set[str]: ] SKIP_DIRS.extend(map(_as_posix_path, filter(None, map(_module_dir, BUILTIN_SKIPLIST)))) +BUILTIN_INLINE_WHEN_CALLED.update(filter(None, (_module_dir(copy),))) + SKIP_DIRS_RE = re.compile(r"match nothing^") # Skip fbcode paths(including torch.package paths) containing @@ -3721,6 +3730,12 @@ def check_file(filename: str | None, is_inlined_call: bool = False) -> SkipResul return SkipResult(False, f"file matches LEGACY_MOD_INLINELIST ({d})") if is_inlined_call and is_torch_inline_allowed(filename): return SkipResult(False, f"file matches MOD_INLINELIST ({filename})") + if is_inlined_call and any( + filename.startswith(d) for d in BUILTIN_INLINE_WHEN_CALLED + ): + return SkipResult( + False, f"file matches BUILTIN_INLINE_WHEN_CALLED ({filename})" + ) if ( is_fbcode() and FBCODE_SKIP_DIRS diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index e90eca40ed5df..9a280517a3710 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -4806,6 +4806,7 @@ def is_tensor_base_attr_getter(value: Any) -> bool: return ( isinstance(value, types.MethodWrapperType) and value.__name__ == "__get__" + and hasattr(value.__self__, "__objclass__") and value.__self__.__objclass__ is torch._C._TensorBase # type: ignore[attr-defined] ) diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 0a6303fe8088f..2eb78823feff9 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -219,6 +219,10 @@ def is_side_effect_safe(m: MutationType) -> bool: return m.scope == scope_id +class NO_SUCH_SUBOBJ: + """Sentinel indicating no concrete Python object is available.""" + + # This helps users of `as_python_constant` to catch unimplemented error with # more information; it inherits `NotImplementedError` for backward # compatibility reasons. @@ -879,6 +883,13 @@ def get_python_hash(self) -> int: ], ) + def get_real_python_backed_value(self) -> object: + """Return the Python object this VT wraps, for `is` comparison. + + Returns NO_SUCH_SUBOBJ if no concrete Python object is available. + """ + return NO_SUCH_SUBOBJ + def is_python_equal(self, other: object) -> bool: """ NB - Deliberately not overriding the __eq__ method because that can diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 5c4dff89a52ce..ce2eaaf535500 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1372,16 +1372,19 @@ def build_key_value( elif value is set_allocator: return TritonSetAllocatorVariable(value) elif isinstance(value, torch.amp.autocast_mode.autocast): - self.install_guards(GuardBuilder.ID_MATCH) - return AutocastModeVariable( - target_values=[ - value.device, - value.fast_dtype, - value._enabled, - value._cache_enabled, - ], - source=self.source, - ) + if isinstance(value, torch.amp.autocast_mode._UnmanagedAutocast): + return self.wrap_user_defined(value) + else: + self.install_guards(GuardBuilder.ID_MATCH) + return AutocastModeVariable( + target_values=[ + value.device, + value.fast_dtype, + value._enabled, + value._cache_enabled, + ], + source=self.source, + ) elif TorchCtxManagerClassVariable.is_matching_cls(value): if inspect.isclass(value): self.install_guards(GuardBuilder.CLASS_MATCH) @@ -2673,6 +2676,7 @@ def wrap_symint( return self.tx.output.unspec_variable_map[self.name] shape_env = self.tx.output.shape_env + frame_state_entry: FrameStateSizeEntry | None = None if TracingContext.get().force_unspec_int_unbacked_size_like: wrapped_value = shape_env.create_unbacked_symint() _constrain_range_for_size(wrapped_value) @@ -2736,10 +2740,17 @@ def wrap_symint( self.install_guards(GuardBuilder.CONSTANT_MATCH) return ConstantVariable.create(value=value) + excluded_scalar = ( + frame_state_entry.excluded_scalar + if config.automatic_dynamic_exclusion_guard + and frame_state_entry is not None + else None + ) wrapped_value = shape_env.create_unspecified_symint_and_symbol( value, source=self.source, dynamic_dim=dynamic_dim, + excluded_value=excluded_scalar, ) self.tx.output.tracked_fakes.append( @@ -3483,6 +3494,7 @@ def handle_traced_output( torch._C._get_mem_efficient_sdp_enabled, torch._C._get_math_sdp_enabled, torch._C._get_overrideable_sdp_enabled, + torch._C._is_autocast_available, "is_integer", ] + list(supported_const_comparison_op_values.keys()) @@ -4015,6 +4027,7 @@ def update_dim2constraint( shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, shape_ids=getattr(e, "_dynamo_shape_ids", None), unbacked_bounds=getattr(e, "_dynamo_unbacked_bounds", None), + excluded_sizes=frame_state_entry.excluded_sizes, ) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 98aa657322bdc..d0fe3c4978758 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -81,12 +81,18 @@ str_methods, tensortype_to_dtype, ) -from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker +from .base import ( + AsPythonConstantNotImplementedError, + NO_SUCH_SUBOBJ, + ValueMutationNew, + VariableTracker, +) from .constant import ( CONSTANT_VARIABLE_FALSE, CONSTANT_VARIABLE_NONE, ConstantVariable, EnumVariable, + FakeIdVariable, ) from .dicts import ( ConstDictVariable, @@ -107,7 +113,6 @@ TupleIteratorVariable, TupleVariable, ) -from .streams import EventVariable, StreamVariable from .tensor import ( FakeItemVariable, supported_comparison_ops, @@ -484,15 +489,11 @@ def _binop_handlers() -> dict[ # combinations. Handlers are attempted in order, and will be used if the type checks # match. They are expected to have the signature: # fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker - from .functions import BaseUserFunctionVariable, UserFunctionVariable + from .functions import BaseUserFunctionVariable from .nn_module import NNModuleVariable from .tensor import supported_const_comparison_ops from .torch import BaseTorchVariable - from .user_defined import ( - UserDefinedClassVariable, - UserDefinedObjectVariable, - UserDefinedVariable, - ) + from .user_defined import UserDefinedVariable # Override table contains: op_fn -> [list of handlers] op_handlers: dict[Any, list[Any]] = {} @@ -825,41 +826,6 @@ def never( op_var = BuiltinVariable(op) result.extend( [ - ( - ( - (UserFunctionVariable, BuiltinVariable), - (UserFunctionVariable, BuiltinVariable), - ), - lambda tx, a, b: VariableTracker.build(tx, op(a.fn, b.fn)), - ), - ( - ( - NNModuleVariable, - NNModuleVariable, - ), - lambda tx, a, b: VariableTracker.build( - tx, - op( - tx.output.get_submodule(a.module_key), - tx.output.get_submodule(b.module_key), - ), - ), - ), - ( - (UserDefinedObjectVariable, UserDefinedObjectVariable), - compare_by_value, - ), - ( - (UserDefinedClassVariable, UserDefinedClassVariable), - compare_by_value, - ), - ( - ( - (StreamVariable, EventVariable, ConstantVariable), - (StreamVariable, EventVariable, ConstantVariable), - ), - compare_by_value, - ), ( (TensorVariable, VariableTracker), op_var._comparison_with_tensor, @@ -884,22 +850,43 @@ def handle_is( left: VariableTracker, right: VariableTracker, ) -> VariableTracker | None: - # If the two objects are of different type, we can safely return False - # and True for `is` and `is not`, respectively - if type(left) is not type(right): - return VariableTracker.build(tx, op.__name__ != "is_") + # VT identity → Python identity if left is right: - return VariableTracker.build(tx, op(left, right)) - if istype(left, variables.ObjectVariable) and istype( - right, variables.ObjectVariable - ): - return VariableTracker.build(tx, op(left.value, right.value)) + return VariableTracker.build(tx, op.__name__ == "is_") + + # Compare underlying Python objects via hook + left_val = left.get_real_python_backed_value() + right_val = right.get_real_python_backed_value() + + left_known = left_val is not NO_SUCH_SUBOBJ + right_known = right_val is not NO_SUCH_SUBOBJ + + if left_known and right_known: + result = left_val is right_val + return VariableTracker.build( + tx, result if op.__name__ == "is_" else not result + ) + + # One side has a concrete value, the other doesn't — they + # can't be identical (if they were the same object, both + # sides would resolve). + if left_known != right_known: + return VariableTracker.build(tx, op.__name__ != "is_") + + # Mutable containers created during tracing: VT identity + # = Python identity. Already False from `left is right`. + if isinstance(left, (ConstDictVariable, ListVariable)): + return VariableTracker.build(tx, op.__name__ != "is_") + + # Different exception types are never identical if ( istype(left, variables.ExceptionVariable) and istype(right, variables.ExceptionVariable) and left.exc_type is not right.exc_type ): - return VariableTracker.build(tx, op(left, right)) + return VariableTracker.build(tx, op.__name__ != "is_") + + return None result.append(((VariableTracker, VariableTracker), handle_is)) # type: ignore[arg-type] @@ -944,6 +931,9 @@ def __repr__(self) -> str: def as_python_constant(self) -> Any: return self.fn + def get_real_python_backed_value(self) -> Any: + return self.fn + def as_proxy(self) -> Any: DTYPE = { bool: torch.bool, @@ -3037,26 +3027,29 @@ def call_id( nn_mod_variable = args[0] mod = tx.output.get_submodule(nn_mod_variable.module_key) return VariableTracker.build(tx, id(mod)) - elif len(args) == 1 and isinstance( - args[0], - (variables.UserDefinedClassVariable, variables.UserDefinedObjectVariable), - ): - if args[0].source: - if isinstance(args[0], variables.UserDefinedClassVariable): - install_guard(args[0].source.make_guard(GuardBuilder.CLASS_MATCH)) - else: - install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH)) - constant_result = id(args[0].value) - return VariableTracker.build(tx, constant_result) elif len(args) == 1 and args[0].is_tensor(): tensor_variable = cast(TensorVariable, args[0]) return tensor_variable.call_id(tx) - elif istype(args[0], variables.UserFunctionVariable): - return VariableTracker.build(tx, id(args[0].fn)) - elif istype(args[0], variables.SkipFunctionVariable): - return VariableTracker.build(tx, id(args[0].value)) elif istype(args[0], variables.FunctoolsPartialVariable): return VariableTracker.build(tx, id(args[0].fake_value)) + elif len(args) == 1: + arg = args[0] + if isinstance( + arg, + ( + variables.UserDefinedClassVariable, + variables.UserDefinedObjectVariable, + ), + ): + if arg.source: + if isinstance(arg, variables.UserDefinedClassVariable): + install_guard(arg.source.make_guard(GuardBuilder.CLASS_MATCH)) + else: + install_guard(arg.source.make_guard(GuardBuilder.ID_MATCH)) + real_val = arg.get_real_python_backed_value() + if real_val is not NO_SUCH_SUBOBJ: + return VariableTracker.build(tx, id(real_val)) + return FakeIdVariable(id(arg)) else: unimplemented( gb_type="id() with unsupported args", diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 7eff45a42d733..bfd72e0897e9d 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -384,12 +384,65 @@ def is_python_equal(self, other: object) -> bool: and self.as_python_constant() == other.as_python_constant() ) + def get_real_python_backed_value(self) -> object: + return self.value + CONSTANT_VARIABLE_NONE = ConstantVariable(None) CONSTANT_VARIABLE_TRUE = ConstantVariable(True) CONSTANT_VARIABLE_FALSE = ConstantVariable(False) +class FakeIdVariable(VariableTracker): + """A compile-time-only id value that can be used as a dict key but cannot + be reconstructed across graph breaks. + + When dynamo evaluates ``id(x)`` on a variable tracker that has no + corresponding runtime object (e.g. a ``ConstDictVariable`` created during + tracing), we mint a fake integer id. This variable holds that id and + supports the minimal interface needed to participate as a dict key + (hashing and equality). It intentionally blocks reconstruction so that a + graph break does not silently bake a stale id into the resumed bytecode. + """ + + def __init__(self, value: int, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.value = value + + def as_python_constant(self) -> int: + return self.value + + def is_python_constant(self) -> bool: + return False + + def python_type(self) -> type: + return int + + def is_python_hashable(self) -> bool: + return True + + def get_python_hash(self) -> int: + return hash(self.value) + + def is_python_equal(self, other: object) -> bool: + if isinstance(other, (FakeIdVariable, ConstantVariable)): + return self.value == other.as_python_constant() + return False + + def reconstruct(self, codegen: Any) -> None: + unimplemented( + gb_type="Reconstruction of FakeIdVariable", + context=str(self.value), + explanation=( + "A fake id produced by id() on a compile-time container " + "cannot be reconstructed across a graph break." + ), + hints=[ + "Avoid using id() on containers in code that may graph-break.", + ], + ) + + class EnumVariable(VariableTracker): """VariableTracker for enum.Enum and enum.IntEnum instances @@ -429,6 +482,9 @@ def __repr__(self) -> str: def as_python_constant(self) -> enum.Enum | enum.IntEnum: return self.value + def get_real_python_backed_value(self) -> enum.Enum | enum.IntEnum: + return self.value + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if not hasattr(self.value, name): raise NotImplementedError diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 4ad9fe34e846c..a8afcd51e5ca5 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -587,6 +587,11 @@ def as_python_constant(self) -> Any: # subclasses (such as methods) usually aren't a constant return super().as_python_constant() + def get_real_python_backed_value(self) -> Any: + if istype(self, UserFunctionVariable): + return self.fn + return super().get_real_python_backed_value() + def self_args(self) -> list[VariableTracker]: return [] @@ -2109,6 +2114,9 @@ def __init__(self, value: Any, reason: str | None = None, **kwargs: Any) -> None def as_python_constant(self) -> Any: return self.value + def get_real_python_backed_value(self) -> Any: + return self.value + @classmethod def create_with_source(cls, value: Any, source: Source) -> "SkipFunctionVariable": # Use closure match guard (i.e. guard on __code__ object instead of diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 10f01ca87f5e7..eccc077f24e59 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -55,6 +55,9 @@ def __repr__(self) -> str: def as_python_constant(self) -> Any: return self.value + def get_real_python_backed_value(self) -> Any: + return self.value + def call_function( self, tx: "InstructionTranslator", diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 03cbb774eca25..973e8ab959d5a 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -71,6 +71,7 @@ ) from .base import ( AsPythonConstantNotImplementedError, + NO_SUCH_SUBOBJ, raise_type_error_exc, VariableTracker, ) @@ -84,10 +85,6 @@ from torch._dynamo.symbolic_convert import InstructionTranslator -class NO_SUCH_SUBOBJ: - pass - - class SuperVariable(VariableTracker): _nonvar_fields = { *VariableTracker._nonvar_fields, @@ -1430,6 +1427,9 @@ def __init__(self, method_wrapper: types.MethodWrapperType, **kwargs: Any) -> No super().__init__(**kwargs) self.method_wrapper = method_wrapper + def get_real_python_backed_value(self) -> types.MethodWrapperType: + return self.method_wrapper + def call_function( self, tx: "InstructionTranslator", @@ -1567,6 +1567,9 @@ def __init__(self, desc: types.GetSetDescriptorType, **kwargs: Any) -> None: super().__init__(**kwargs) self.desc = desc + def get_real_python_backed_value(self) -> types.GetSetDescriptorType: + return self.desc + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "__get__" and self.source: source = AttrSource(self.source, "__get__") @@ -1602,6 +1605,9 @@ def python_type(self) -> type[types.ModuleType]: def as_python_constant(self) -> types.ModuleType: return self.value + def get_real_python_backed_value(self) -> types.ModuleType: + return self.value + def __repr__(self) -> str: return f"PythonModuleVariable({self.value})" @@ -1677,6 +1683,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker def as_python_constant(self) -> Any: return self.value + def get_real_python_backed_value(self) -> Any: + return self.value + def reconstruct(self, codegen: "PyCodegen") -> None: if not isinstance(self.value, types.GenericAlias): return super().reconstruct(codegen) @@ -1760,6 +1769,9 @@ def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value + def get_real_python_backed_value(self) -> Any: + return self.value + @classmethod def can_constant_fold_through(cls, fn: types.FunctionType) -> bool: mod = fn.__module__.split(".") @@ -2002,6 +2014,9 @@ def __init__(self, value: object, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value + def get_real_python_backed_value(self) -> object: + return self.value + def python_type(self) -> type[object]: return object @@ -2084,6 +2099,9 @@ def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value + def get_real_python_backed_value(self) -> Any: + return self.value + def call_function( self, tx: "InstructionTranslator", @@ -2102,6 +2120,9 @@ def __init__(self, value: logging.Logger, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value + def get_real_python_backed_value(self) -> logging.Logger: + return self.value + def call_method( self, tx: "InstructionTranslator", diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 2f0784ca2b9d1..857ba2c418aa9 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -215,6 +215,9 @@ def set_nn_module_stack_source(self, source: Source) -> None: def python_type(self) -> type: return self.module_type + def get_real_python_backed_value(self) -> object: + return self.value + def _wrap_submodule( self, tx: "InstructionTranslator", diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index 4577fb4ea7c9f..7a861d3d7caf9 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -327,6 +327,9 @@ def __init__( def python_type(self) -> type: return torch.Stream + def get_real_python_backed_value(self) -> object: + return self.value + def call_method( self, tx: "InstructionTranslator", @@ -473,6 +476,9 @@ def __init__( self.value = value self.user_object_index = user_object_index + def get_real_python_backed_value(self) -> object: + return self.value + def call_method( self, tx: "InstructionTranslator", diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 81c183eb9493b..a602cbc875827 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -783,6 +783,16 @@ def call_method( hints=[], ) + if name == "__deepcopy__": + unimplemented( + gb_type="Attempted to copy.deepcopy a tensor", + context=f"copy.deepcopy({self})", + explanation="Dynamo does not support copy.deepcopy() on tensors.", + hints=[ + "Avoid calling copy.deepcopy() on tensors inside compiled regions.", + ], + ) + # Only override builtin tensor methods # The user can manually add override handling # with a decorator for other methods (e.g. a dispatch subclass with other methods) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 319d074d46d87..047ef88d7e4f5 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -171,6 +171,7 @@ torch.cuda.is_initialized, torch.xpu.current_device, torch.xpu.is_initialized, + torch.__future__.get_overwrite_module_params_on_conversion, ] constant_fold_functions = [ @@ -429,6 +430,9 @@ def as_proxy(self) -> Any: def as_python_constant(self) -> Any: return self.value + def get_real_python_backed_value(self) -> Any: + return self.value + def call_obj_hasattr( self, tx: "InstructionTranslator", name: str ) -> ConstantVariable: @@ -965,6 +969,60 @@ def handle_use_deterministic_algorithms( torch._C._set_deterministic_algorithms(value) return CONSTANT_VARIABLE_NONE + @register(torch.autocast_increment_nesting) + def handle_autocast_increment_nesting( + self, tx: "InstructionTranslator" + ) -> VariableTracker: + tx.output.create_node( + "call_function", torch.autocast_increment_nesting, (), {} + ) + prev = torch.autocast_increment_nesting() + tx.output.add_cleanup_hook(lambda: torch.autocast_decrement_nesting()) + return VariableTracker.build(tx, prev) + + @register(torch.autocast_decrement_nesting) + def handle_autocast_decrement_nesting( + self, tx: "InstructionTranslator" + ) -> VariableTracker: + tx.output.create_node( + "call_function", torch.autocast_decrement_nesting, (), {} + ) + prev = torch.autocast_decrement_nesting() + tx.output.add_cleanup_hook(lambda: torch.autocast_increment_nesting()) + return VariableTracker.build(tx, prev) + + @register(torch.set_autocast_enabled) + def handle_set_autocast_enabled( + self, + tx: "InstructionTranslator", + device_type: VariableTracker, + enabled: VariableTracker, + ) -> VariableTracker: + tx.output.create_node( + "call_function", + torch.set_autocast_enabled, + (device_type.as_proxy(), enabled.as_proxy()), + ) + dev_py_const = device_type.as_python_constant() + prev = torch.is_autocast_enabled(dev_py_const) + torch.set_autocast_enabled(dev_py_const, enabled.as_python_constant()) + tx.output.add_cleanup_hook( + lambda: torch.set_autocast_enabled(dev_py_const, prev) + ) + return CONSTANT_VARIABLE_NONE + + @register(torch.set_autocast_cache_enabled) + def handle_set_autocast_cache_enabled( + self, tx: "InstructionTranslator", enabled: VariableTracker + ) -> VariableTracker: + tx.output.create_node( + "call_function", torch.set_autocast_cache_enabled, (enabled.as_proxy(),) + ) + prev = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(enabled.as_python_constant()) + tx.output.add_cleanup_hook(lambda: torch.set_autocast_cache_enabled(prev)) + return CONSTANT_VARIABLE_NONE + @register(torch.are_deterministic_algorithms_enabled) def handle_are_deterministic_algorithms_enabled( self, tx: "InstructionTranslator" @@ -1921,6 +1979,52 @@ def handle_check( ), ) + def exchange_device_helper( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + fn: Callable[[int], int | None], + ) -> VariableTracker: + if len(args) != 1 or kwargs: + raise_type_error_exc( + tx, + f"{fn.__name__} takes exactly one argument ({len(args)} given)", + ) + current_device_source = CallFunctionNoArgsSource( + AttrSource(AttrSource(ImportSource("torch"), "cuda"), "current_device") + ) + install_guard(current_device_source.make_guard(GuardBuilder.EQUALS_MATCH)) + arg = args[0].as_python_constant() + prev = fn(arg) + tx.output.create_node( + "call_function", + fn, + (arg,), + {}, + ) + tx.output.add_cleanup_hook(lambda: torch.cuda.set_device(prev)) + return VariableTracker.build(tx, prev) + + @register(torch.cuda._exchange_device) + def handle_exchange_device( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + return exchange_device_helper(tx, args, kwargs, torch.cuda._exchange_device) + + @register(torch.cuda._maybe_exchange_device) + def handle_maybe_exchange_device( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + return exchange_device_helper( + tx, args, kwargs, torch.cuda._maybe_exchange_device + ) + @register(torch.autograd.grad) def handle_autograd_grad(self, tx: "InstructionTranslator", *args, **kwargs): """ diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index f1b535e5dc015..7e8c6447f8a9f 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -96,7 +96,13 @@ tuple_methods, unpatched_nn_module_getattr, ) -from .base import MutationType, raise_type_error_exc, ValueMutationNew, VariableTracker +from .base import ( + MutationType, + NO_SUCH_SUBOBJ, + raise_type_error_exc, + ValueMutationNew, + VariableTracker, +) from .dicts import ConstDictVariable, DefaultDictVariable, SetVariable @@ -1048,6 +1054,9 @@ def is_python_equal(self, other: object) -> bool: and self.value is other.value ) + def get_real_python_backed_value(self) -> object: + return self.value + class UserDefinedExceptionClassVariable(UserDefinedClassVariable): @property @@ -1111,10 +1120,6 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke return super().var_getattr(tx, name) -class NO_SUCH_SUBOBJ: - pass - - class RemovableHandleClass: # Dummy class to pass to python_type of # RemovableHandleVariable @@ -1235,6 +1240,9 @@ def is_underlying_vt_modified(self, side_effects: "SideEffects") -> bool: def python_type(self) -> type: return self.value_type # type: ignore[return-value] + def get_real_python_backed_value(self) -> object: + return self.value + def as_python_constant(self) -> object: if self.is_pytree_constant_class and self.source: # NOTE pytree constants created in the torch.compile region will diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 83d7450e0fb85..2f1aae39d0a1a 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -9,7 +9,7 @@ from collections import defaultdict from collections.abc import Callable, Sequence from contextlib import contextmanager -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, TypeGuard import torch import torch.utils._pytree as pytree @@ -623,7 +623,7 @@ def produce_guards_and_solve_constraints( raise constraint_violation_error -def is_int(x: object) -> bool: +def is_int(x: object) -> TypeGuard[int | torch.SymInt]: return isinstance(x, int) or (isinstance(x, torch.SymInt) and x.node.expr.is_number) diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 5ced53e2e123e..f5fcf7a019ff2 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -7,6 +7,7 @@ import base64 import contextlib import functools +import hashlib import json import logging import os @@ -22,7 +23,12 @@ import torch from torch._dynamo.precompile_context import PrecompileContext from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions -from torch._dynamo.utils import chromium_event_log_active, CompileEventLogger, counters +from torch._dynamo.utils import ( + chromium_event_log_active, + CompileEventLogger, + counters, + warn_once, +) from torch._functorch import config from torch._inductor.codecache import ( _ident, @@ -516,6 +522,107 @@ def __init__(self, gm: torch.fx.GraphModule) -> None: } ) + # pyrefly: ignore [bad-override] + def reducer_override(self, obj: Any) -> Any: + """ + Override to handle tensor subclasses (like DTensor) that aren't caught + by the dispatch_table's exact type matching. + + The dispatch_table only matches exact types, so subclasses like DTensor + fall through to the default __reduce_ex__ which includes non-deterministic + storage addresses. This method catches those cases using isinstance checks. + """ + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + # Handle tensor subclasses that aren't exactly torch.Tensor + # dispatch_table already handles torch.Tensor exactly + if isinstance(obj, torch.Tensor) and type(obj) is not torch.Tensor: + if hasattr(obj, "_stable_hash_for_caching"): + return (_ident, (obj._stable_hash_for_caching(),)) + if is_traceable_wrapper_subclass(obj): + warn_once( + f"{type(obj).__name__} does not implement _stable_hash_for_caching. " + "For PT2-compatible tensor subclasses, it is recommended to implement " + "_stable_hash_for_caching(self) -> str for stable AOT autograd caching." + ) + return (_ident, (self._default_stable_hash_for_caching(obj),)) + return self._reduce_tensor(obj) + # Return NotImplemented to fall back to default behavior + return NotImplemented + + # [NOTE] Tensor subclass stable hashing for AOT autograd cache + # Python's hash() varies with PYTHONHASHSEED, making cache keys unstable + # across processes. We use blake2b for cross-process determinism. + # + # EXTENSION POINT: Traceable wrapper subclasses can override cache key + # generation by implementing _stable_hash_for_caching(self) -> str. + # This method should return a deterministic string that uniquely identifies + # the tensor's metadata for caching purposes. See DTensor for an example. + # + # We can't define a default method on subclasses because there is no abstract + # base subclass, and we don't want to pollute torch.Tensor. Instead, we provide + # a default implementation here that uses __tensor_flatten__ to recursively + # hash inner tensors and metadata. + + def _get_stable_hash(self, obj: Any) -> str: + """ + Get stable hash for a tensor or opaque object, dispatching to custom or default implementation. + """ + from torch._opaque_base import OpaqueBase + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if hasattr(obj, "_stable_hash_for_caching"): + return obj._stable_hash_for_caching() + elif isinstance(obj, torch.Tensor) and is_traceable_wrapper_subclass(obj): + return self._default_stable_hash_for_caching(obj) + elif isinstance(obj, OpaqueBase): + # Opaque objects are runtime pass-throughs; only the type matters + # for cache key purposes, not the instance identity or value. + type_name = type(obj).__qualname__ + return hashlib.blake2b(type_name.encode(), digest_size=16).hexdigest() + elif isinstance(obj, torch.Tensor): + metadata = extract_tensor_metadata_for_cache_key(obj) + return hashlib.blake2b(pickle.dumps(metadata), digest_size=16).hexdigest() + else: + return hashlib.blake2b(pickle.dumps(obj), digest_size=16).hexdigest() + + def _default_stable_hash_for_caching(self, tensor: torch.Tensor) -> str: + """ + Default stable hash implementation for traceable wrapper subclasses. + """ + from torch._opaque_base import OpaqueBase + + inner_tensor_names, subclass_metadata = tensor.__tensor_flatten__() # type: ignore[attr-defined] + + # Recursively get hashes of inner tensors/opaque objects + inner_hashes: dict[str, str] = {} + for name in inner_tensor_names: + inner = getattr(tensor, name) + inner_hashes[name] = self._get_stable_hash(inner) + + # Stabilize metadata: replace OpaqueBase instances with their type name + # since their repr includes memory addresses + def _stabilize(obj: Any) -> Any: + if isinstance(obj, OpaqueBase): + return type(obj).__qualname__ + if isinstance(obj, tuple): + return tuple(_stabilize(x) for x in obj) + if isinstance(obj, list): + return [_stabilize(x) for x in obj] + if isinstance(obj, dict): + return {k: _stabilize(v) for k, v in obj.items()} + return obj + + cache_data = pickle.dumps( + ( + tensor.shape, + tensor.requires_grad, + _stabilize(subclass_metadata), + inner_hashes, + ) + ) + return hashlib.blake2b(cache_data, digest_size=16).hexdigest() + def _reduce_aot_config( self, aot_config: AOTConfig ) -> tuple[Callable[..., Any], tuple[Any, ...]]: diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index d1ef0476d4b78..c455f7ff9c092 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -65,11 +65,13 @@ _LINKER_SCRIPT, _set_gpu_runtime_env, _TORCH_PATH, + batch_convert_cubins_to_obj, convert_cubin_to_obj, CppBuilder, CppOptions, CppTorchDeviceOptions, get_compiler_version_info, + get_cpp_compiler, get_ld_and_objcopy, get_name_and_dir_from_output_file_path, normalize_path_separator, @@ -2541,6 +2543,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: cubins_o = [] asm_files = [] if not _IS_WINDOWS: + cubins_to_embed: list[tuple[str, str]] = [] ld, objcopy = get_ld_and_objcopy(use_relative_path) kernels = getattr(V.graph.wrapper_code, "_kernel_name_to_body", {}) for kernel_name, value in CudaKernelParamCache.cache.items(): @@ -2615,10 +2618,30 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: log.info("Created multi-arch bundle: %s", cubin_file) if config.aot_inductor.embed_kernel_binary: - # Embed cubin files into model.so using objcopy - cubins_o.append( - convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy) + cubins_to_embed.append((cubin_file, kernel_name)) + + if cubins_to_embed: + # Batch all cubins into a single .o using .incbin assembly. + # This replaces N * 3 subprocess calls (ld + 2x objcopy per + # cubin) with a single compiler invocation. + try: + combined_obj = batch_convert_cubins_to_obj( + cubins_to_embed, + os.path.dirname(output_so), + cpp_compiler=get_cpp_compiler(), + ) + cubins_o.append(combined_obj) + except subprocess.CalledProcessError: + log.warning( + "Batched cubin embedding failed, " + "falling back to per-cubin objcopy" ) + for cubin_file, kernel_name in cubins_to_embed: + cubins_o.append( + convert_cubin_to_obj( + cubin_file, kernel_name, ld, objcopy + ) + ) output_name, output_dir = get_name_and_dir_from_output_file_path(output_so) so_build_options = CppTorchDeviceOptions( diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index cc2c5eb7be29b..1df9b20862e47 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -43,7 +43,6 @@ torch.bool: "bool", torch.bfloat16: "at::BFloat16", torch.complex32: "at::complex", - torch.complex32: "at::complex", torch.complex64: "at::complex", torch.complex128: "at::complex", torch.float8_e4m3fn: "at::Float8_e4m3fn", @@ -69,7 +68,6 @@ torch.bool: "at::kBool", torch.bfloat16: "at::kBFloat16", torch.complex32: "at::kComplexHalf", - torch.bcomplex32: "at::kBComplex32", torch.complex64: "at::kComplexFloat", torch.complex128: "at::kComplexDouble", torch.float8_e4m3fn: "at::kFloat8_e4m3fn", diff --git a/torch/_inductor/codegen/cutlass/scheduling.py b/torch/_inductor/codegen/cutlass/scheduling.py index 74f3cd9504482..4d906eb22f243 100644 --- a/torch/_inductor/codegen/cutlass/scheduling.py +++ b/torch/_inductor/codegen/cutlass/scheduling.py @@ -2,7 +2,7 @@ import hashlib import logging from collections.abc import Sequence -from typing import cast +from typing import cast, TypeGuard from torch._inductor.codegen.cutlass.python_evt import ( CutlassEVTCodegen, @@ -53,7 +53,7 @@ def group_fn(self, sizes): return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) @staticmethod - def is_cutlass_template(node: BaseSchedulerNode) -> bool: + def is_cutlass_template(node: BaseSchedulerNode) -> TypeGuard[SchedulerNode]: return isinstance(node, SchedulerNode) and isinstance( node.node, CUTLASSTemplateBuffer ) @@ -136,7 +136,6 @@ def codegen_template( assert self.is_cutlass_template(template_node), ( "Template node passed to CUTLASSScheduling.codegen_template must be a SchedulerNode that wraps a CUTLASSTemplateBuffer" ) - template_node = cast(SchedulerNode, template_node) _, (_numel, rnumel) = template_node.group assert rnumel == 1 ctb: CUTLASSTemplateBuffer = cast(CUTLASSTemplateBuffer, template_node.node) diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 9d121330b2519..449e72f99839f 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -1097,6 +1097,20 @@ def format_threads(threads: list[str], kwarg: str) -> str: arg_types=arg_types, ) + def device_assert_async(self, cond: CSEVariable, msg: str) -> None: + if V.graph.cpp_wrapper: + self.cse.generate(self.compute, f"if (!{cond}) return", assignment=False) + else: + self.headers.add("error") + self.compute.writelines( + [ + f"if (!{cond}) {{", + f" TORCH_REPORT_ERROR(error_buf, {repr(msg)});", + " return;", + "}", + ] + ) + def check_bounds( self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool ) -> None: diff --git a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py index ec58e458df6b1..eb8769de8db79 100644 --- a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py +++ b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import logging from collections.abc import Sequence -from typing import cast +from typing import cast, TypeGuard from ... import config from ...codecache import code_hash, get_path @@ -28,7 +28,7 @@ def group_fn(self, sizes): return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) @staticmethod - def is_rocm_cpp_template(node: BaseSchedulerNode) -> bool: + def is_rocm_cpp_template(node: BaseSchedulerNode) -> TypeGuard[SchedulerNode]: return isinstance(node, SchedulerNode) and isinstance( node.node, ROCmTemplateBuffer ) @@ -82,7 +82,6 @@ def codegen_template( assert self.is_rocm_cpp_template(template_node), ( "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer" ) - template_node = cast(SchedulerNode, template_node) _, (_numel, rnumel) = template_node.group assert rnumel == 1 ctb: ROCmTemplateBuffer = cast(ROCmTemplateBuffer, template_node.node) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index b9bef7f6ed9bb..853a8d7de8dda 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -5707,8 +5707,6 @@ def add_constexpr_arg(arg_name): if flops is not None: inductor_meta["kernel_flop"] = flops - triton_meta["configs"] = [config_of(signature)] - # Triton compiler includes equal_to_1 args into constants even # when they are not constexpr. otherwise there may be a segfault # during launching the Inductor-compiled Triton kernel. @@ -5724,6 +5722,14 @@ def add_constexpr_arg(arg_name): self.codegen_body() self._filter_pdl(self.body) + # Compute configs after codegen_body() so we know if the kernel + # uses atomic ops. On HIP, buffer ops don't support atomics, so + # we must not tag any args with pointer_range_32 in that case. + if torch.version.hip is not None and self.atomic_add_found: + triton_meta["configs"] = [config_of(signature, pointer_range_override=())] + else: + triton_meta["configs"] = [config_of(signature)] + for helper in self.helper_functions: code.writeline("") code.splice(helper) diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index d4adb2aaea473..8c2f9a58f02e9 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -160,6 +160,21 @@ def _decide_tl_dtype(arg): } +def _get_buffer_layout(buf_name: str) -> "torch._inductor.ir.Layout": + """Get the layout for a buffer, handling both scheduler buffers and graph inputs.""" + if V.graph.scheduler: + layout = V.graph.scheduler.get_buffer_layout(buf_name) + else: + buffer = V.graph.try_get_buffer(buf_name) + # output arg + if not buffer: + assert buf_name == V.kernel.output_node.name + layout = V.kernel.output_node.layout + else: + layout = buffer.get_layout() + return layout + + def is_unaligned_buffer(arg: TensorArg): buf_name = arg.buffer if buf_name in V.graph.unaligned_buffers: @@ -175,17 +190,7 @@ def is_unaligned_buffer(arg: TensorArg): # all constants are assumed to be aligned return False - if V.graph.scheduler: - layout = V.graph.scheduler.get_buffer_layout(buf_name) - else: - buffer = V.graph.try_get_buffer(buf_name) - # output arg - if not buffer: - assert buf_name == V.kernel.output_node.name - layout = V.kernel.output_node.layout - else: - layout = buffer.get_layout() - + layout = _get_buffer_layout(buf_name) if isinstance(layout, torch._inductor.ir.NonOwningLayout): return not layout.maybe_guard_aligned() else: @@ -213,10 +218,36 @@ def equal_1_arg_indices( return equal_to_1 +def _is_tensor_within_2gb(arg: TensorArg) -> bool: + """Check if a tensor argument's storage is provably within 2GB. + + Mirrors HIPBackend.is_within_2gb() but uses compile-time symbolic analysis + instead of runtime tensor inspection. This enables canonicalize_pointers to + decompose pointer arithmetic into (splat(base), offset) form for buffer ops. + """ + MAX_BYTES = 2**31 - 1 + try: + # Graph inputs aren't tracked by the scheduler; get their layout + # from the graph_inputs dict to avoid KeyError in get_buffer_layout. + if arg.buffer in V.graph.graph_inputs: + inp = V.graph.graph_inputs[arg.buffer] + if hasattr(inp, "get_layout"): + layout = inp.get_layout() + else: + return False + else: + layout = _get_buffer_layout(arg.buffer) + storage_bytes = layout.storage_size() * arg.dtype.itemsize + return V.graph.sizevars.statically_known_true(storage_bytes <= MAX_BYTES) + except Exception: + return False + + def config_of( args: list[KernelArgType], *, indices: list[int] | None = None, + pointer_range_override: tuple[int, ...] | None = None, ) -> Any: if indices is None: indices = list(range(len(args))) @@ -263,5 +294,18 @@ def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool: equal_to_1 = equal_1_arg_indices(args, indices=indices) - # pyrefly: ignore [bad-argument-type] - return AttrsDescriptorWrapper(divisible_by_16, equal_to_1) + # On AMD/HIP, tag tensor args whose storage fits in 2GB so Triton + # can use 32-bit pointer offsets and emit buffer load/store ops. + if pointer_range_override is not None: + pointer_range_32 = pointer_range_override + elif torch.version.hip is not None: + pointer_range_32 = tuple( + i + for i, arg in zip(indices, args) + if isinstance(arg, TensorArg) and _is_tensor_within_2gb(arg) + ) + else: + pointer_range_32 = () + + # pyrefly: ignore [bad-argument-count, bad-argument-type] + return AttrsDescriptorWrapper(divisible_by_16, equal_to_1, pointer_range_32) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index abc14130afabf..5c92dd60f69bc 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1866,10 +1866,6 @@ class triton: os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION_ALLOW_MULTI_STAGES") == "1" ) - enable_tlx_templates: bool = ( - os.environ.get("TORCHINDUCTOR_ENABLE_TLX_TEMPLATES", "0") == "1" - ) - # Map for storing the amount of kernel runs with dumped input tensors # Based on hash of Triton source code to avoid bloating the folder debug_dump_kernel_inputs: dict[str, int] = {} diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index a40397c5439fb..d3330e4434971 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -437,6 +437,56 @@ def convert_cubin_to_obj( return obj_file +def batch_convert_cubins_to_obj( + cubins: list[tuple[str, str]], + output_dir: str, + cpp_compiler: str = "gcc", +) -> str: + """Convert multiple cubin files to a single .o using batched .incbin assembly. + + Instead of spawning 3 subprocesses per cubin (ld + 2x objcopy), generates + a single .S file with .incbin directives for all cubins and compiles it + with one compiler invocation. Produces bit-identical rodata and symbols + as the per-cubin convert_cubin_to_obj approach. + + Args: + cubins: list of (cubin_file_path, kernel_name) tuples. + output_dir: directory for the generated .S and .o files. + cpp_compiler: C compiler to use for assembling (default: gcc). + + Returns: + Path to the combined .o file. + """ + asm_path = os.path.join(output_dir, "cubins_combined.S") + obj_path = os.path.join(output_dir, "cubins_combined.o") + + with open(asm_path, "w") as f: + f.write(".section .rodata\n") + for cubin_file, kernel_name in cubins: + # Use absolute path to avoid issues with working directory + abs_cubin = os.path.abspath(cubin_file) + escaped_path = abs_cubin.replace("\\", "\\\\").replace('"', '\\"') + f.write( + f".balign 16\n" + f".global __{kernel_name}_start\n" + f".global __{kernel_name}_end\n" + f"__{kernel_name}_start:\n" + f'.incbin "{escaped_path}"\n' + f"__{kernel_name}_end:\n" + f".global __{kernel_name}_size\n" + f".set __{kernel_name}_size, " + f"__{kernel_name}_end - __{kernel_name}_start\n" + ) + + subprocess.run( + [cpp_compiler, "-c", asm_path, "-o", obj_path], + capture_output=True, + text=True, + check=True, + ) + return obj_path + + @functools.cache def _is_apple_clang(cpp_compiler: str) -> bool: version_string = subprocess.check_output([cpp_compiler, "--version"]).decode("utf8") diff --git a/torch/_inductor/fx_passes/auto_chunker/applier.py b/torch/_inductor/fx_passes/auto_chunker/applier.py index 97c675ad702c2..42139a2d542cd 100644 --- a/torch/_inductor/fx_passes/auto_chunker/applier.py +++ b/torch/_inductor/fx_passes/auto_chunker/applier.py @@ -275,6 +275,22 @@ def _create_placeholder_node(input_node: Node) -> Node: ) continue + # Chunk aten.view: adjust the target shape at the chunk dimension + if ( + original_node.target == aten.view.default + and isinstance(original_node.args[0], torch.fx.Node) + and (meta := get_chunking_meta(original_node)) is not None + and meta.chunk_dim is not None + ): + shape = list(original_node.args[1]) # type: ignore[arg-type] + shape[meta.chunk_dim] = chunk_size + env[original_node] = new_graph.call_function( + aten.view.default, + (env[original_node.args[0]], shape), # type: ignore[arg-type] + original_node.kwargs, + ) + continue + # create the node with chunked inputs env[original_node] = new_graph.node_copy(original_node, lambda x: env[x]) diff --git a/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py b/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py index 5d8ba6a6b74a2..e0ea9844df50c 100644 --- a/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py +++ b/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py @@ -136,14 +136,15 @@ def propagate_where(where_node: Node) -> bool: aten.exp.default, aten.log.default, aten.tanh.default, + aten.eq.Tensor, ] ) -def propagate_nonlinear_requires_no_scaling(out_node: Node) -> bool: +def propagate_requires_no_scaling(out_node: Node) -> bool: """ - For nonlinear ops like exp, log, tanh, scale_by cannot be propagated - through since f(S*x) != S*f(x). These ops typically appear in the chunking - subgraph when the final gradient is 1 (i.e. scale_by is None), - making scaling a no-op. + For nonlinear ops (exp, log, tanh) scale_by cannot be propagated + through since f(S*x) != S*f(x). For boolean-output ops (eq) scale_by + is meaningless. These ops only appear in the chunking subgraph when + scale_by is None (e.g. the final gradient is 1). """ args_node = get_args_of_node_type(out_node) args_meta = get_chunking_metas(args_node) @@ -165,9 +166,13 @@ def propagate_nonlinear_requires_no_scaling(out_node: Node) -> bool: aten.neg.default, aten.sum.dim_IntList, aten.sum.default, # sum to scalar + aten.amax.default, aten.mm.default, aten.permute.default, aten.expand.default, + aten.squeeze.dim, + aten.unsqueeze.default, + aten.view.default, ] ) def propagate_general_copy(out_node: Node) -> bool: diff --git a/torch/_inductor/fx_passes/auto_chunker/propagator.py b/torch/_inductor/fx_passes/auto_chunker/propagator.py index 57c89923cc9de..223a75a20d80b 100644 --- a/torch/_inductor/fx_passes/auto_chunker/propagator.py +++ b/torch/_inductor/fx_passes/auto_chunker/propagator.py @@ -1,5 +1,6 @@ import functools import logging +import math from collections.abc import Callable, Sequence from enum import Enum from queue import Queue @@ -396,6 +397,7 @@ def bwd() -> PropagateStatus: prims.fma.default, aten.where.self, aten.neg.default, + aten.eq.Tensor, ] ) def propagate_general_copy_metadata( @@ -636,6 +638,118 @@ def bwd() -> PropagateStatus: return fwd(), bwd() +@register_propagate_rule(aten.unsqueeze.default) +def propagate_unsqueeze(unsqueeze_node: Node) -> _HandlerRetType: + input_node, unsqueeze_dim = unsqueeze_node.args[:2] + assert isinstance(input_node, Node) + assert isinstance(unsqueeze_dim, int) + input_ndim = get_fake_tensor_from_node_arg(input_node).ndim # type: ignore[union-attr] + # Normalize negative dim: unsqueeze valid range is [-(ndim+1), ndim] + normalized_dim = ( + unsqueeze_dim + input_ndim + 1 if unsqueeze_dim < 0 else unsqueeze_dim + ) + + def fwd() -> PropagateStatus: + assert isinstance(input_node, Node) + input_meta = get_chunking_meta(input_node) + if input_meta is None: + return _bool_to_status(False) + if input_meta.chunk_dim is None: + return _bool_to_status(copy_chunking_meta(unsqueeze_node, input_meta)) + + # pyrefly: ignore[unsupported-operation] + new_dim = input_meta.chunk_dim + ( + 1 if input_meta.chunk_dim >= normalized_dim else 0 + ) + return _bool_to_status( + set_chunking_meta(unsqueeze_node, meta=input_meta, chunk_dim=new_dim) + ) + + def bwd() -> PropagateStatus: + assert isinstance(input_node, Node) + output_meta = get_chunking_meta(unsqueeze_node) + if output_meta is None: + return _bool_to_status(False) + if output_meta.chunk_dim is None: + return _bool_to_status(copy_chunking_meta(input_node, output_meta)) + # pyrefly: ignore[unsupported-operation] + new_dim = output_meta.chunk_dim - ( + 1 if output_meta.chunk_dim > normalized_dim else 0 + ) + return _bool_to_status( + set_chunking_meta(input_node, meta=output_meta, chunk_dim=new_dim) + ) + + return fwd(), bwd() + + +def _find_chunk_dim_after_reshape( + old_shape: Sequence[int], new_shape: Sequence[int], chunk_dim: int +) -> int | None: + """ + Find the equivalent chunk_dim position after a reshape by matching + the prefix product (number of elements before the dimension) and + the dimension size. Returns None if the chunk dimension is merged + or split by the reshape, making it unsafe to propagate. + + Examples: + [M, N] -> [M, N, 1], chunk_dim=0: returns 0 (trailing dim added) + [M] -> [M, 1], chunk_dim=0: returns 0 + [M, N] -> [M1, M2, N] where M1*M2=M, chunk_dim=0: returns None (split) + [M, N] -> [M*N], chunk_dim=0: returns None (merged) + """ + chunk_size = old_shape[chunk_dim] + old_offset = math.prod(old_shape[:chunk_dim]) + new_offset = 1 + for new_dim in range(len(new_shape)): + if new_offset == old_offset and new_shape[new_dim] == chunk_size: + return new_dim + new_offset *= new_shape[new_dim] + return None + + +@register_propagate_rule(aten.view.default) +def propagate_view(view_node: Node) -> _HandlerRetType: + input_node = view_node.args[0] + assert isinstance(input_node, Node) + input_shape = list(get_fake_tensor_from_node_arg(input_node).shape) # type: ignore[union-attr] + output_shape = list(get_fake_tensor_from_node_arg(view_node).shape) # type: ignore[union-attr] + + def fwd() -> PropagateStatus: + assert isinstance(input_node, Node) + input_meta = get_chunking_meta(input_node) + if input_meta is None: + return _bool_to_status(False) + if input_meta.chunk_dim is None: + return _bool_to_status(copy_chunking_meta(view_node, input_meta)) + new_dim = _find_chunk_dim_after_reshape( + input_shape, output_shape, input_meta.chunk_dim + ) + if new_dim is None: + return PropagateStatus.FAIL + return _bool_to_status( + set_chunking_meta(view_node, meta=input_meta, chunk_dim=new_dim) + ) + + def bwd() -> PropagateStatus: + assert isinstance(input_node, Node) + output_meta = get_chunking_meta(view_node) + if output_meta is None: + return _bool_to_status(False) + if output_meta.chunk_dim is None: + return _bool_to_status(copy_chunking_meta(input_node, output_meta)) + new_dim = _find_chunk_dim_after_reshape( + output_shape, input_shape, output_meta.chunk_dim + ) + if new_dim is None: + return PropagateStatus.FAIL + return _bool_to_status( + set_chunking_meta(input_node, meta=output_meta, chunk_dim=new_dim) + ) + + return fwd(), bwd() + + @register_propagate_rule( [ aten.expand.default, diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 017db2d471b8f..d036746a62556 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -362,12 +362,10 @@ def _deduce_value(self, node: torch.fx.Node): # handle before view ops because this changes value if node.target is aten.view.dtype: (input_tensor, output_dtype), kwargs = self.fetch_args_kwargs_from_env(node) - # view.dtype fails on 0-d tensors when element size changes - # (e.g., 0-d complex tensors can't be viewed as float) - if ( - input_tensor.ndim == 0 - and input_tensor.element_size() != output_dtype.itemsize - ): + # view.dtype with different element sizes changes element count + # (e.g., complex64 [1+0j] viewed as float32 becomes [1.0, 0.0]), + # making uniform values non-uniform. Also crashes on 0-d tensors. + if input_tensor.element_size() != output_dtype.itemsize: return self.unknown_value return super(ConstantFolder, self).run_node(node) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 5d1b33aee32e7..7b947937646e6 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -1037,8 +1037,20 @@ def register_fun(cond): return register_fun +def _needs_spmd_graph_preservation() -> bool: + """Check if SPMD graph preservation is needed for distributed overlap.""" + return ( + config.aten_distributed_optimizations.enable_overlap_scheduling + or config.reorder_for_compute_comm_overlap + ) + + @register_noop_decomp(aten.slice) def slice_noop(self, dim=0, start=None, end=None, step=1): + if _needs_spmd_graph_preservation(): + # Keep no-op slices so all ranks produce identical FX graphs (SPMD) + # with matching op counts and runtime estimations. + return False if start is None or end is None: return False @@ -1082,6 +1094,10 @@ def repeat_noop(self, repeats): @register_noop_decomp(aten.constant_pad_nd) def constant_pad_nd(x, padding, fill_value=0): + if _needs_spmd_graph_preservation(): + # Keep no-op pads so all ranks produce identical FX graphs (SPMD) + # with matching op counts and runtime estimations. + return False return all(p == 0 for p in padding) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 50d14b4dc6326..f315e042dec60 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3643,11 +3643,6 @@ def loader(idx: Sequence[Expr]) -> OpsValue: class SliceView(View): - """View that represents a slice along a single dimension. - - Corresponds to tensor[..., start:end:step, ...]. - """ - @classmethod def normalize_start_end( cls, x: IRNode, dim: int, start: int, end: int @@ -3662,14 +3657,6 @@ def normalize_start_end( if any(free_unbacked_symbols(x) for x in (start, end, dim_size)): min_func = sympy.Min max_func = sympy.Max - elif any( - # Only needed when backed_size_oblivious is on. - x.has(sympy.Min, sympy.Max) - for x in (start, end, dim_size) - if isinstance(x, Expr) - ): - min_func = sympy.Min - max_func = sympy.Max else: min_func = sizevars.evaluate_min max_func = sizevars.evaluate_max diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index ca14e049c5ba4..0031cacee95e7 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -424,14 +424,6 @@ def _to_dtype(x): elif use_triton_tma_template(mat1, mat2, output_layout=layout): templates_to_use.append(persistent_tma_mm_template) - if ( - inductor_config.is_fbcode() - and inductor_config.triton.enable_tlx_templates - ): - from torch._inductor.fb.tlx_templates.mm_templates import append_tlx - - templates_to_use = append_tlx(templates_to_use) - templates_to_use.append(mm_contiguous_subgraph_template) choices.extend( diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 2252e3c4dc03f..4d8ae1ff494b6 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1429,12 +1429,6 @@ def compute_slice_index(index, size, default=None): return size elif fn(sympy.Lt(index, -size)): return 0 - elif fn(sympy.Ge(index, 0)): - # If index >= 0, the resolved index is at most min(index, size). - return sympy.Min(index, size) - elif fn(sympy.Lt(index, 0)): - # If index < 0, wrap and clamp: the resolved index is at least 0. - return sympy.Max(index + size, 0) return None start_index, end_index = None, None diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index a9ddf91e9a59c..af91cdece3bd3 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -47,6 +47,7 @@ class TileHint(Enum): def AttrsDescriptorWrapper( divisible_by_16=None, equal_to_1=None, + pointer_range_32=None, ): # Prepare the arguments for AttrsDescriptor kwargs = { @@ -69,6 +70,7 @@ def AttrsDescriptorWrapper( def AttrsDescriptorWrapper( divisible_by_16=None, equal_to_1=None, + pointer_range_32=None, ): # Prepare the arguments for AttrsDescriptor kwargs = { @@ -88,17 +90,27 @@ def AttrsDescriptorWrapper( def AttrsDescriptorWrapper( divisible_by_16=None, equal_to_1=None, + pointer_range_32=None, ): # pyrefly: ignore [not-iterable] - return {(x,): [["tt.divisibility", 16]] for x in divisible_by_16} + # Build attr dict merging divisibility and pointer_range per arg index, + # since a single arg can carry both attributes. + result = {(x,): [["tt.divisibility", 16]] for x in (divisible_by_16 or ())} + for x in pointer_range_32 or (): + key = (x,) + if key in result: + result[key].append(["tt.pointer_range", 32]) + else: + result[key] = [["tt.pointer_range", 32]] + return result else: # Define a namedtuple as a fallback when AttrsDescriptor is not available AttrsDescriptorWrapper = collections.namedtuple( # type: ignore[no-redef, name-match] # pyrefly: ignore [invalid-argument] "AttrsDescriptor", - ["divisible_by_16", "equal_to_1"], - defaults=[(), ()], + ["divisible_by_16", "equal_to_1", "pointer_range_32"], + defaults=[(), (), ()], ) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index fe394fd90ede2..b1d8bf2616803 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -3349,23 +3349,25 @@ def outer_config_opt(): ] if torch.version.hip: - # Skip large-XBLOCK HIP configs when a combo kernel has a persistent - # sub-kernel with a large hardcoded R0_BLOCK. The persistent tile size - # (XBLOCK * max_persistent_rblock) would otherwise cause pathological - # ROCm compilation times (e.g. 1024 * 1024 = 1M elements → 20+ min). - # Use the same 4096-element threshold as _persistent_reduction_configs. - max_persistent_rblock = inductor_meta.get("max_persistent_rblock", 0) hip_configs = [ make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2), make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1), ] + result_configs.extend(hip_configs) + + # Filter ALL configs (not just HIP-specific ones) when a combo kernel + # has a persistent sub-kernel with a large hardcoded R0_BLOCK. The + # persistent tile size (XBLOCK * max_persistent_rblock) causes + # pathological ROCm compilation times (e.g. 64 * 1024 = 64K elements + # → 60+ min triton.compile). Use the same 4096-element threshold as + # _persistent_reduction_configs. + max_persistent_rblock = inductor_meta.get("max_persistent_rblock", 0) if max_persistent_rblock > 0: - hip_configs = [ + result_configs = [ c - for c in hip_configs + for c in result_configs if c.kwargs.get("XBLOCK", 0) * max_persistent_rblock <= 4096 ] - result_configs.extend(hip_configs) return result_configs diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 2d146bca76b38..5b68096eda24f 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -361,6 +361,9 @@ def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: if not V.graph.sizevars.statically_known_leq(ncol, 1024 * 16): return False + if MixOrderReduction.is_split_reduction(contiguous_node): + return False + # Other reduction types like max/min is not supported yet. # There are no real use case as well. out = all( diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index a67bd749b3f40..1571b4560d6f5 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -549,20 +549,6 @@ def evaluate_min(self, left: Expr, right: Expr) -> Expr: if right == gcd: return right - # Min/Max fallback: we can prove Min(a, b) <= c when any arg <= c, but - # sympy doesn't simplify this yet. So, evaluate it here. Same for Max. - for lhs, rhs in [(left, right), (right, left)]: - - def le_rhs(a: Expr) -> bool: - return self.guard_or_false(sympy.Le(a, rhs)) - - # Min(Min(a, b), c) ==> Min(a, b) if (a <= c) or (b <= c). - if isinstance(lhs, sympy.Min) and any(le_rhs(a) for a in lhs.args): - return lhs - # Min(Max(a, b), c) ==> Max(a, b) if (a <= c) and (b <= c). - if isinstance(lhs, sympy.Max) and all(le_rhs(a) for a in lhs.args): - return lhs - raise TypeError( f"evaluate_min({left}, {right}) with unbacked symints" ) from None diff --git a/torch/_inductor/template_heuristics/tlx.py b/torch/_inductor/template_heuristics/tlx.py index 83381687cd5de..e75457aceb1cf 100644 --- a/torch/_inductor/template_heuristics/tlx.py +++ b/torch/_inductor/template_heuristics/tlx.py @@ -3,5 +3,3 @@ if config.is_fbcode(): import torch._inductor.fb.tlx_templates.registry # noqa: F401 # type: ignore[import-not-used] - -# TODO. Move the registry to this file once the TLX template is more complete. diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py index a75048900b96f..13c65ab25c520 100644 --- a/torch/_inductor/template_heuristics/triton.py +++ b/torch/_inductor/template_heuristics/triton.py @@ -780,6 +780,12 @@ def __init__(self) -> None: for num_warps in [2, 4, 8] ] + def _get_extra_config_key_and_kwargs( + self, conf: BaseConfig + ) -> tuple[tuple[int | None, ...], dict[str, Any]]: + """Hook for subclasses to extend config dedup key and kwargs.""" + return (), {} + def _finalize_mm_configs( self, configs: list[BaseConfig], @@ -814,16 +820,8 @@ def _finalize_mm_configs( if isinstance(conf, BlackwellGPUGemmConfig): key += (conf.epilogue_subtile, conf.warp_specialize, conf.flatten) - # Add TlxGemmConfig specific fields to key if present - if config.is_fbcode() and config.triton.enable_tlx_templates: - from torch._inductor.fb.tlx_templates.registry import ( - get_tlx_config_key_and_kwargs, - ) - - tlx_key_fields, tlx_kwargs = get_tlx_config_key_and_kwargs(conf) - key += tlx_key_fields - else: - tlx_kwargs = {} + extra_key, extra_kwargs = self._get_extra_config_key_and_kwargs(conf) + key += extra_key if key not in used and ( max_mm_configs is None or len(used) < max_mm_configs @@ -844,8 +842,7 @@ def _finalize_mm_configs( kwargs["WARP_SPECIALIZE"] = conf.warp_specialize kwargs["FLATTEN"] = conf.flatten - # Add TlxGemmConfig specific fields if present - kwargs.update(tlx_kwargs) + kwargs.update(extra_kwargs) yield self.triton_config(conf.num_stages, num_warps, **kwargs) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index c45cbb7724708..9227ff808ced7 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5035,7 +5035,8 @@ def unpack(name, val): return nInputPlane, outputHeight, outputWidth -@register_meta(aten.max_pool2d_with_indices_backward.default) +@register_meta(aten.max_pool2d_with_indices_backward) +@out_wrapper("grad_input") def meta_max_pool2d_with_indices_backward( grad_output, self, diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index e49f08eda873c..4705b6986eb04 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -1123,13 +1123,8 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]: torch.int32, torch.int64, ) -_low_precision_dtypes = ( - torch.float16, - torch.bfloat16, - torch.complex32, - torch.bcomplex32, -) -_complex_dtypes = (torch.complex32, torch.bcomplex32, torch.complex64, torch.complex128) +_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32) +_complex_dtypes = (torch.complex32, torch.complex64, torch.complex128) def is_boolean_dtype(dtype: torch.dtype) -> bool: @@ -1173,12 +1168,11 @@ def is_grad_dtype(dtype: torch.dtype) -> bool: torch.complex128: torch.float64, torch.complex64: torch.float32, torch.complex32: torch.float16, - torch.bcomplex32: torch.bfloat16, } _real_to_complex_dtype_map = { torch.float16: torch.complex32, - torch.bfloat16: torch.bcomplex32, + torch.bfloat16: torch.complex64, torch.float32: torch.complex64, torch.float64: torch.complex128, } @@ -1367,7 +1361,7 @@ def _extract_dtype( (torch.float16, torch.bfloat16), (torch.float32,), (torch.float64,), - (torch.complex32, torch.bcomplex32), + (torch.complex32,), (torch.complex64,), (torch.complex128,), ) @@ -1497,7 +1491,6 @@ def check_same_dtype(*args): torch.bfloat16: torch.float32, torch.float16: torch.float32, torch.complex32: torch.complex64, - torch.bcomplex32: torch.complex64, } @@ -1605,7 +1598,7 @@ def elementwise_dtypes( partially ordered as follows: bool -> uint8, int8 -> int16 -> int32 -> int64 -> - float16, bfloat16 -> float32 -> float64 -> complex32, bcomplex32 -> complex64 -> complex128 + float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128 The result dtype is selected by: - if no tensor's dtype has the same corresponding type as the one selected, @@ -1625,7 +1618,7 @@ def elementwise_dtypes( The "corresponding complex dtypes" are: float16 -> complex32 - bfloat16 -> bcomplex32 + bfloat16 -> complex64 float32 -> complex64 float64 -> complex128 complex32 -> complex32 @@ -1636,7 +1629,7 @@ def elementwise_dtypes( dtype by mapping low precision floating point and complex dtypes as follows: float16 -> float32 - bfloat16 -> bcomplex32 + bfloat16 -> float32 complex32 -> complex64 This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index b26d554d051bc..7c089a68a57ec 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -50,6 +50,7 @@ elementwise_unary_scalar_wrapper, out_wrapper, ) +from torch.testing._internal.common_dtype import highest_precision_float # Experimental module containing prototype Python references for existing @@ -5469,13 +5470,13 @@ def linspace( start.dim() == 0, lambda: "linspace only supports 0-dimensional start and end tensors", ) - start = _maybe_convert_to_dtype(start, torch.float64) + start = _maybe_convert_to_dtype(start, highest_precision_float(device)) if isinstance(end, TensorLikeType): torch._check( end.dim() == 0, lambda: "linspace only supports 0-dimensional start and end tensors", ) - end = _maybe_convert_to_dtype(end, torch.float64) + end = _maybe_convert_to_dtype(end, highest_precision_float(device)) if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)): default_complex_dtype = utils.corresponding_complex_dtype( diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 76aebe6d45783..a602e266ba11e 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -961,11 +961,6 @@ def _compute_slice_index(size: IntLikeType, index: IntLikeType) -> IntLikeType | return 0 elif guard_or_false(index > size): return size - elif guard_or_false(index >= 0): - return torch.sym_min(index, size) - elif guard_or_false(index < 0): - return torch.sym_max(index + size, 0) - return None @@ -1010,12 +1005,6 @@ def slice_forward( new_size = (end_index - start_index + step - 1) // step elif guard_or_false(start_index >= end_index): new_size = 0 - else: - # Both indices are resolved but we can't statically determine their - # ordering (e.g., when they involve Min/Max). Compute the size via - # max(end - start, 0) to avoid creating an unbacked symint. - diff = torch.sym_max(end_index - start_index, 0) - new_size = (diff + step - 1) // step # create unbacked if case unknown if new_size is None: diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index 529e47a382b7b..892b9eeef646c 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -392,6 +392,11 @@ def __call__(self, func): return autocast_decorator(self, func) +# Subclass to distinguish autocast variables created by _enter_autocast (and not managed by a with statement) +class _UnmanagedAutocast(autocast): + pass + + # These functions aren't meant for public usage. # They are what we trace into a graph during pre_dispatch tracing # when we encounter an autocast context manager. @@ -401,7 +406,7 @@ def _enter_autocast(*vals): return torch.overrides.handle_torch_function( torch.amp._enter_autocast, [], *vals ) - mode = torch.amp.autocast(*vals) + mode = _UnmanagedAutocast(*vals) mode.__enter__() return mode diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index a7346daa283a5..95fc679d79ae4 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -3,6 +3,7 @@ import warnings from collections import namedtuple from typing import Any +from typing_extensions import TypeIs import torch import torch.ao.nn.intrinsic as nni @@ -327,7 +328,9 @@ def node_supports_equalization(node: Node, modules) -> bool: return False -def is_equalization_observer(observer: nn.Module) -> bool: +def is_equalization_observer( + observer: nn.Module, +) -> TypeIs[_InputEqualizationObserver | _WeightEqualizationObserver]: return isinstance( observer, (_InputEqualizationObserver, _WeightEqualizationObserver) ) diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index 3d1c2a7269338..0775d43994283 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -376,11 +376,15 @@ class _force_original_view_tracking(_DecoratorContextManager): def __init__(self, mode: bool) -> None: self.prev = torch._C._is_view_replay_enabled() - torch._C._set_view_replay_enabled(mode) self.mode = mode + torch._C._set_view_replay_enabled(mode) + + def __call__(self, orig_func: F) -> F: + torch._C._set_view_replay_enabled(self.prev) + return super().__call__(orig_func) def __enter__(self) -> None: - pass + torch._C._set_view_replay_enabled(self.mode) def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: torch._C._set_view_replay_enabled(self.prev) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 37a4aaa1d8ead..07c332eef9f22 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -698,6 +698,8 @@ def _device_memory_usage(mem_record): flops=kineto_event.flops(), is_user_annotation=kineto_event.is_user_annotation(), metadata_json=kineto_event.metadata_json(), + external_id=kineto_event.external_id(), + linked_correlation_id=kineto_event.linked_correlation_id(), ) max_evt_id = max(max_evt_id, fe.id) if fe.device_type == DeviceType.CPU and not fe.is_async: diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index c51ac870aa37b..02fd0e35e9285 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -638,6 +638,8 @@ def __init__( kwinputs=None, is_user_annotation=False, metadata_json=None, + external_id=0, + linked_correlation_id=0, ): self.id: int = id self.node_id: int = node_id @@ -680,6 +682,8 @@ def __init__( self.total_cpu_percent = -1 self.total_device_percent = -1 self.metadata_json = metadata_json + self.external_id: int = external_id + self.linked_correlation_id: int = linked_correlation_id def append_kernel(self, name, device, duration): if self.device_type != DeviceType.CPU: diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index 9e5aa1e5eaa69..9281bf2608d75 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -165,10 +165,6 @@ void initModule(PyObject* module) { return at::accelerator::getMemoryInfo(device_index); }); - m.def("_accelerator_getAllocatorSettings", []() { - return c10::CachingAllocator::getAllocatorSettings(); - }); - m.def("_accelerator_setAllocatorSettings", [](std::string env) { c10::CachingAllocator::setAllocatorSettings(env); }); diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index c5b1ad5253b2f..1641766b1b6d5 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -989,6 +989,7 @@ Tensor unbind_backward_nested( int64_t dim, const at::TensorOptions& options) { std::vector grads_tensors; + grads_tensors.reserve(grads.size()); for (int64_t i : c10::irange(static_cast(grads.size()))) { if (grads[i].defined()) { grads_tensors.push_back(static_cast(grads[i])); @@ -2185,6 +2186,7 @@ Tensor _nested_split_with_sizes_backward( // it's possible some of the grads are not defined (represents tensors of all // 0s). Since at::cat can't handle those, let's define them std::vector grads_all_defined; + grads_all_defined.reserve(grads.size()); for (int64_t i : c10::irange(static_cast(grads.size()))) { if (grads[i].defined()) { grads_all_defined.push_back(static_cast(grads[i])); @@ -5347,7 +5349,7 @@ Tensor _cudnn_ctc_loss_backward( bool zero_infinity) { if (zero_infinity) { return at::where( - loss.unsqueeze(0).unsqueeze(2).isinf(), + loss.unsqueeze(0).unsqueeze(2) == 0, at::zeros({}, raw_grad.options()), raw_grad * grad_out.unsqueeze(0).unsqueeze(2)); } else { diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 3e26190dd9ec6..360d6c70867c7 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -293,6 +293,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { .def( "linked_correlation_id", [](const KinetoEvent& e) { return e.linkedCorrelationId(); }) + .def("external_id", [](const KinetoEvent& e) { return e.externalId(); }) // compute flops .def("flops", [](const KinetoEvent& e) { return e.flops(); }) // Whether this is async event or not diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index 059fe43defb44..9d4f5f043ea21 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -1107,6 +1107,40 @@ std::string KinetoEvent::metadataJson() const { [](const auto&) -> std::string { return std::string(""); })); } +int64_t KinetoEvent::externalId() const { + // Mirrors libkineto::ChromeTraceLogger::handleActivity() "External id" logic. + // libkineto::ChromeTraceLogger checks op.linkedActivity() != nullptr; here we + // check linkedCorrelationId() > 0, which is equivalent because PyTorch + // correlation IDs are monotonically increasing from 1 (a valid linked + // activity always has a non-zero correlation ID). + uint64_t linked = linkedCorrelationId(); + if (linked > 0) { + return static_cast(linked); + } + + // Orphaned GPU activities (no linked CPU op) in these types should not get + // an External id, to avoid incorrect cross-linking in trace viewers. + auto type = static_cast(activityType()); + if (type != libkineto::ActivityType::GPU_MEMCPY && + type != libkineto::ActivityType::GPU_MEMSET && + type != libkineto::ActivityType::CONCURRENT_KERNEL && + type != libkineto::ActivityType::CUDA_RUNTIME && + type != libkineto::ActivityType::CUDA_DRIVER && + type != libkineto::ActivityType::PRIVATEUSE1_RUNTIME && + type != libkineto::ActivityType::PRIVATEUSE1_DRIVER) { + return static_cast(result_->visit(c10::overloaded( + [](const ExtraFields& e) -> uint64_t { + return e.correlation_id_; + }, + [](const ExtraFields& e) -> uint64_t { + return e.correlation_id_; + }, + [](const auto&) -> uint64_t { return 0; }))); + } + + return 0; +} + #define FORWARD_FROM_RESULT(method_name, result_expr) \ decltype(std::declval().method_name()) \ KinetoEvent::method_name() const { \ diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h index ab0b792716eeb..e03008b2b758c 100644 --- a/torch/csrc/autograd/profiler_kineto.h +++ b/torch/csrc/autograd/profiler_kineto.h @@ -58,6 +58,7 @@ struct TORCH_API KinetoEvent { bool isAsync() const; uint64_t correlationId() const; uint64_t linkedCorrelationId() const; + int64_t externalId() const; int64_t deviceResourceId() const; std::string backend() const; bool isPythonFunction() const; diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 2b97fd593cfe4..f2023fea5bd6c 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -1097,6 +1097,7 @@ void _trace_post_record( } std::vector trace_outputs; + trace_outputs.reserve(static_cast(std::max(0, num_outputs))); for (const auto i : c10::irange(num_outputs)) { PyObject* obj = PyTuple_GET_ITEM(output_objects, i); if (THPVariable_Check(obj)) { diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index e09883e27f1ef..e21a9ccbe5635 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2724,9 +2724,7 @@ void stop_recording_dict_pointers( bool is_recording_dict_pointers(RootGuardManager* root); void record_dict_pointer(RootGuardManager* root, PyObject* dict_pointer); void record_tensor_pointer(RootGuardManager* root, PyObject* tensor_pointer); -void record_tensor_requires_grad( - RootGuardManager* root, - PyObject* tensor_pointer); +void record_tensor_metadata(RootGuardManager* root, PyObject* tensor_pointer); GuardManager* clone_guard_manager( GuardManager* from, @@ -2737,12 +2735,53 @@ void add_relational_guard_resetter_to_cloned_root( std::shared_ptr guard); std::shared_ptr get_no_tensor_aliasing_guard( RootGuardManager* _root); +const LocalState& get_local_state(RootGuardManager* root); // std::string get_compile_id(RootGuardManager* root); struct WeakEntry { PyObject* wr; // weakref PyObject* cap; // capsule whose m_self is used by the callback }; + +// Convert concrete sizes/strides to the optional vectors that +// TensorCheck expects. All dimensions are treated as static (no nullopt). +inline std::vector> to_opt_symint( + c10::IntArrayRef vals) { + std::vector> out; + out.reserve(vals.size()); + for (auto v : vals) { + out.emplace_back(c10::SymInt(v)); + } + return out; +} + +// Build a TensorCheck that validates all concrete metadata (dispatch key, +// dtype, device, requires_grad, sizes, strides) for the dict-tag fast path. +inline TensorCheck make_tensor_check( + const LocalState& state, + const at::Tensor& tensor) { + auto layout = tensor.layout(); + bool sparse = layout == c10::kSparseCsr || layout == c10::kSparseCsc || + layout == c10::kSparseBsc || layout == c10::kSparseBsr; + // Sparse layouts don't support strides; use nullopt per dim so + // TensorCheck skips stride comparison for each dimension. + auto strides = sparse + ? std::vector>(tensor.dim(), std::nullopt) + : to_opt_symint(tensor.strides()); + return TensorCheck( + state, + /*pt=*/nullptr, + tensor, + tensor.key_set(), + to_opt_symint(tensor.sizes()), + std::move(strides)); +} + +struct RecordedTensorMetadata { + PyObject* tensor_ptr; + TensorCheck check; +}; + /** * Base class representing a pair of accessor and the associated guard * manager. The accessor defines how to access the child value from the @@ -3012,10 +3051,10 @@ class GuardManager { _tensor_pointers[value] = tensor_pointers; } - void stash_tensor_requires_grad( + void stash_tensor_metadata( PyObject* value, - std::vector>&& tensor_requires_grad) { - _tensor_requires_grad_pointers[value] = std::move(tensor_requires_grad); + std::vector&& tensor_metadata) { + _tensor_metadata_pointers[value] = std::move(tensor_metadata); } void disable_recursive_dict_tag_optimization() { @@ -3144,15 +3183,16 @@ class GuardManager { return true; } - bool check_tensor_requires_grad_fast(PyObject* value) const { - auto it = _tensor_requires_grad_pointers.find(value); - if (it == _tensor_requires_grad_pointers.end()) { + bool check_tensor_metadata_fast(PyObject* value) { + auto it = _tensor_metadata_pointers.find(value); + if (it == _tensor_metadata_pointers.end()) { return true; } - for (const auto& [tensor_ptr, expected_requires_grad] : it->second) { - if (THPVariable_Check(tensor_ptr) && - THPVariable_Unpack(tensor_ptr).requires_grad() != - expected_requires_grad) { + for (auto& recorded_tensor : it->second) { + if (!THPVariable_Check(recorded_tensor.tensor_ptr) || + !recorded_tensor.check.check( + get_local_state(_root), + THPVariable_Unpack(recorded_tensor.tensor_ptr))) { return false; } } @@ -3185,15 +3225,17 @@ class GuardManager { // For a `tag_safe_root`, the input pointer called `value`, the object the // guard is inspecting, serves as a proxy for the entire nested dictionary // structure beneath that node. If this `value` pointer is one we have - // already recorded, then verifying each dictionary’s tag is sufficient to - // prove that nothing inside the subtree has changed. + // already recorded, then verifying each dictionary’s tag plus the cached + // tensor metadata is sufficient to prove that nothing inside the subtree + // has changed. // // Runtime flow // ------------- // 1) Previously‑seen `value` pointer // • Look up the current `value` pointer in our cache. - // • If found, perform a recursive tag comparison on the cached subtree. - // All tags match means guard passes with no further traversal. + // • If found, perform a recursive tag comparison on the cached subtree + // and revalidate recorded tensor metadata. + // All checks passing means guard passes with no further traversal. // // 2) First‑time `value` pointer // • Enter recording mode; walk the subtree, each tag safe root collects @@ -3236,7 +3278,7 @@ class GuardManager { // Check for fast path // if (is_weakref_valid(value) && check_dict_pointer_tags(value)) { if (check_dict_pointer_tags(value) && - check_tensor_requires_grad_fast(value)) { + check_tensor_metadata_fast(value)) { if (check_no_tensor_aliasing_guards_fast(value)) { return true; } else { @@ -3266,9 +3308,9 @@ class GuardManager { } else if (_has_no_tensor_aliasing_guard) { record_tensor_pointer(_root, value); } - // Record tensor requires_grad for all tensors in the subtree. + // Tensor metadata can mutate in-place without changing dict tags. if (_is_immutable && THPVariable_Check(value)) { - record_tensor_requires_grad(_root, value); + record_tensor_metadata(_root, value); } } } @@ -3674,8 +3716,8 @@ class GuardManager { std::unordered_map>> _dict_pointers; std::unordered_map> _tensor_pointers; - std::unordered_map>> - _tensor_requires_grad_pointers; + std::unordered_map> + _tensor_metadata_pointers; std::vector _tag_safe_entries; // 3.12+ related helper @@ -3944,7 +3986,7 @@ class RootGuardManager : public GuardManager { _current_tag_safe_root = nullptr; _recorded_dict_pointers.clear(); _recorded_tensor_pointers.clear(); - _recorded_tensor_requires_grad.clear(); + _recorded_tensor_metadata.clear(); } void stop_recording_dict_pointers(PyObject* value, bool result) { @@ -3954,8 +3996,8 @@ class RootGuardManager : public GuardManager { value, _recorded_dict_pointers); _current_tag_safe_root->stash_tensor_pointers( value, _recorded_tensor_pointers); - _current_tag_safe_root->stash_tensor_requires_grad( - value, std::move(_recorded_tensor_requires_grad)); + _current_tag_safe_root->stash_tensor_metadata( + value, std::move(_recorded_tensor_metadata)); } reset_dict_tag_recording_variables(); } @@ -3973,9 +4015,11 @@ class RootGuardManager : public GuardManager { _recorded_tensor_pointers.push_back(tensor_pointer); } - void record_tensor_requires_grad(PyObject* tensor_pointer) { - bool rg = THPVariable_Unpack(tensor_pointer).requires_grad(); - _recorded_tensor_requires_grad.emplace_back(tensor_pointer, rg); + void record_tensor_metadata(PyObject* tensor_pointer) { + _recorded_tensor_metadata.push_back(RecordedTensorMetadata{ + tensor_pointer, + make_tensor_check(_local_state, THPVariable_Unpack(tensor_pointer)), + }); } public: @@ -4034,7 +4078,7 @@ class RootGuardManager : public GuardManager { GuardManager* _current_tag_safe_root{nullptr}; std::vector> _recorded_dict_pointers; std::vector _recorded_tensor_pointers; - std::vector> _recorded_tensor_requires_grad; + std::vector _recorded_tensor_metadata; }; /* @@ -4481,10 +4525,8 @@ void record_tensor_pointer(RootGuardManager* root, PyObject* tensor_pointer) { root->record_tensor_pointer(tensor_pointer); } -void record_tensor_requires_grad( - RootGuardManager* root, - PyObject* tensor_pointer) { - root->record_tensor_requires_grad(tensor_pointer); +void record_tensor_metadata(RootGuardManager* root, PyObject* tensor_pointer) { + root->record_tensor_metadata(tensor_pointer); } std::shared_ptr get_no_tensor_aliasing_guard( @@ -4492,6 +4534,10 @@ std::shared_ptr get_no_tensor_aliasing_guard( return _root->get_no_tensor_aliasing_guard(); } +const LocalState& get_local_state(RootGuardManager* root) { + return root->_local_state; +} + // std::string get_compile_id(RootGuardManager* root) { // return root->get_compile_id(); // } @@ -5003,10 +5049,7 @@ class FrameLocalsGuardAccessor : public GuardAccessor { _key(key[0].ptr()), _framelocals_idx(key[1].cast()), _is_immutable_object(is_immutable_object(example_value)), - _is_tensor(THPVariable_Check(example_value.ptr())), - _tensor_requires_grad( - _is_tensor ? THPVariable_Unpack(example_value.ptr()).requires_grad() - : false) {} + _is_tensor(THPVariable_Check(example_value.ptr())) {} // Run as a result of calling run_root_guard_manager/check_nopybind // NB: Intentional duplication between check_nopybind and @@ -5014,18 +5057,8 @@ class FrameLocalsGuardAccessor : public GuardAccessor { bool check_nopybind( FrameLocalsMapping* obj, bool matches_dict_tag = false) override { // borrowed ref - if (matches_dict_tag && _is_immutable_object) { - // Tensors are treated as immutable for the dict-tag optimization, but - // their metadata (e.g. requires_grad) can be mutated in-place without - // changing the parent dict's version tag. For now we only check - // requires_grad since it is the most common mutation; other metadata - // changes (dtype, device, etc.) are possible but rare in practice. - if (!_is_tensor) { - return true; - } - if (!tensor_requires_grad_changed(obj->get(_framelocals_idx))) { - return true; - } + if (matches_dict_tag && _is_immutable_object && !_is_tensor) { + return true; } PyObject* x = obj->get(_framelocals_idx); @@ -5047,19 +5080,8 @@ class FrameLocalsGuardAccessor : public GuardAccessor { PyDict_Check(obj), "FrameLocalsGuardAccessor check expected dict() input"); - if (matches_dict_tag && _is_immutable_object) { - // Tensors are treated as immutable for the dict-tag optimization, but - // their metadata (e.g. requires_grad) can be mutated in-place without - // changing the parent dict's version tag. For now we only check - // requires_grad since it is the most common mutation; other metadata - // changes (dtype, device, etc.) are possible but rare in practice. - if (!_is_tensor) { - return true; - } - PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref - if (!tensor_requires_grad_changed(x)) { - return true; - } + if (matches_dict_tag && _is_immutable_object && !_is_tensor) { + return true; } PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref @@ -5116,15 +5138,9 @@ class FrameLocalsGuardAccessor : public GuardAccessor { to->_framelocals_idx = _framelocals_idx; to->_is_immutable_object = _is_immutable_object; to->_is_tensor = _is_tensor; - to->_tensor_requires_grad = _tensor_requires_grad; } private: - bool tensor_requires_grad_changed(PyObject* x) const { - return x != nullptr && THPVariable_Check(x) && - THPVariable_Unpack(x).requires_grad() != _tensor_requires_grad; - } - PyObject* _key{nullptr}; int _framelocals_idx{-1}; @@ -5132,7 +5148,6 @@ class FrameLocalsGuardAccessor : public GuardAccessor { // return true. bool _is_immutable_object{false}; bool _is_tensor{false}; - bool _tensor_requires_grad{false}; }; /** @@ -5156,30 +5171,15 @@ class DictGetItemGuardAccessor : public GuardAccessor { guard_manager_enum), _key(key.ptr()), _is_immutable_object(is_immutable_object(example_value)), - _is_tensor(THPVariable_Check(example_value.ptr())), - _tensor_requires_grad( - _is_tensor ? THPVariable_Unpack(example_value.ptr()).requires_grad() - : false) {} + _is_tensor(THPVariable_Check(example_value.ptr())) {} // NB: Intentional duplication between check_nopybind and // check_verbose_nopybind. bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) override { - if (matches_dict_tag && _is_immutable_object && + if (matches_dict_tag && _is_immutable_object && !_is_tensor && !is_recording_dict_pointers(get_guard_manager()->get_root()) && _guard_manager->has_no_accessors()) { - // Tensors are treated as immutable for the dict-tag optimization, but - // their metadata (e.g. requires_grad) can be mutated in-place without - // changing the parent dict's version tag. For now we only check - // requires_grad since it is the most common mutation; other metadata - // changes (dtype, device, etc.) are possible but rare in practice. - if (!_is_tensor) { - return true; - } - PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref - if (!tensor_requires_grad_changed(x)) { - return true; - } - // Fall through to full check - requires_grad changed. + return true; } PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref @@ -5226,22 +5226,15 @@ class DictGetItemGuardAccessor : public GuardAccessor { to->_key = _key; to->_is_immutable_object = _is_immutable_object; to->_is_tensor = _is_tensor; - to->_tensor_requires_grad = _tensor_requires_grad; } private: - bool tensor_requires_grad_changed(PyObject* x) const { - return x != nullptr && THPVariable_Check(x) && - THPVariable_Unpack(x).requires_grad() != _tensor_requires_grad; - } - PyObject* _key{nullptr}; // If immutable object and dict tag matches, we can skip the guard subtree and // return true. bool _is_immutable_object{false}; bool _is_tensor{false}; - bool _tensor_requires_grad{false}; }; /** diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp index 0d1e7e0604222..fa232e1a01016 100644 --- a/torch/csrc/profiler/kineto_shim.cpp +++ b/torch/csrc/profiler/kineto_shim.cpp @@ -61,6 +61,7 @@ const ActivityTypeMap kMtiaTypes{ {libkineto::ActivityType::MTIA_CCP_EVENTS, "MTIA_CCP_EVENTS"}, {libkineto::ActivityType::MTIA_RUNTIME, "MTIA_RUNTIME"}, {libkineto::ActivityType::MTIA_INSIGHT, "MTIA_INSIGHT"}, + {libkineto::ActivityType::MTIA_COUNTERS, "MTIA_COUNTERS"}, }; const ActivityTypeMap kHpuTypes{ @@ -356,6 +357,12 @@ void prepareTrace( } else { LOG(INFO) << "Disabling MTIA insight events"; } + if (config.custom_profiler_config.find("disable_counter_events") == + std::string::npos) { + k_activities.insert(libkineto::ActivityType::MTIA_COUNTERS); + } else { + LOG(INFO) << "Disabling MTIA counter events"; + } } } if (activities.count(torch::autograd::profiler::ActivityType::HPU)) { @@ -490,6 +497,7 @@ c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) { // TODO: T151322015 case libkineto::ActivityType::MTIA_CCP_EVENTS: case libkineto::ActivityType::MTIA_INSIGHT: + case libkineto::ActivityType::MTIA_COUNTERS: return device_type_privateuse1_or(c10::DeviceType::MTIA); case libkineto::ActivityType::HPU_OP: return c10::DeviceType::HPU; diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index 2f3d1fce740ce..379c065b22b07 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -486,7 +486,10 @@ convertIValue( itemsize, device_str); return std::make_tuple( - tensor_shape, tensor_stride, tensor_type, tensor_value); + std::move(tensor_shape), + std::move(tensor_stride), + std::move(tensor_type), + std::move(tensor_value)); } else if (val.isTuple()) { const auto& val_tuple = val.toTupleRef().elements(); size_t tuple_size = val_tuple.size(); @@ -494,6 +497,10 @@ convertIValue( std::vector stride_array; std::vector type_array; std::vector value_array; + shape_array.reserve(tuple_size); + stride_array.reserve(tuple_size); + type_array.reserve(tuple_size); + value_array.reserve(tuple_size); for (const auto j : c10::irange(tuple_size)) { auto tuple = convertIValue( ob, @@ -505,17 +512,17 @@ convertIValue( val_tuple[j], false, maxArrayLen); - shape_array.push_back(std::get<0>(tuple)); - stride_array.push_back(std::get<1>(tuple)); - type_array.push_back(std::get<2>(tuple)); - value_array.push_back(std::get<3>(tuple)); + shape_array.push_back(std::move(std::get<0>(tuple))); + stride_array.push_back(std::move(std::get<1>(tuple))); + type_array.push_back(std::move(std::get<2>(tuple))); + value_array.push_back(std::move(std::get<3>(tuple))); } type = type + vectorToString(type_array); std::string tensor_type = baseType ? fmt::format("\"{}\"", type) : type; return std::make_tuple( vectorToString(shape_array), vectorToString(stride_array), - tensor_type, + std::move(tensor_type), vectorToString(value_array)); } else if (val.isList()) { const auto& val_list = val.toList(); @@ -524,6 +531,11 @@ convertIValue( std::vector stride_array; std::vector type_array; std::vector value_array; + const size_t effective_list_size = std::min(list_size, maxArrayLen + 1); + shape_array.reserve(effective_list_size); + stride_array.reserve(effective_list_size); + type_array.reserve(effective_list_size); + value_array.reserve(effective_list_size); for (const auto j : c10::irange(list_size)) { auto tuple = convertIValue( ob, @@ -535,10 +547,10 @@ convertIValue( val_list.get(j), false, maxArrayLen); - shape_array.push_back(std::get<0>(tuple)); - stride_array.push_back(std::get<1>(tuple)); - type_array.push_back(std::get<2>(tuple)); - value_array.push_back(std::get<3>(tuple)); + shape_array.push_back(std::move(std::get<0>(tuple))); + stride_array.push_back(std::move(std::get<1>(tuple))); + type_array.push_back(std::move(std::get<2>(tuple))); + value_array.push_back(std::move(std::get<3>(tuple))); if (j >= maxArrayLen) { LOG(WARNING) << "list size=" << val_list.size() << " exceeded maxArrayLen=" << maxArrayLen; @@ -550,7 +562,7 @@ convertIValue( return std::make_tuple( vectorToString(shape_array), vectorToString(stride_array), - tensor_type, + std::move(tensor_type), vectorToString(value_array)); } else { std::string tensor_shape = "[]"; @@ -559,7 +571,10 @@ convertIValue( std::string tensor_value = getScalarValue(val); return std::make_tuple( - tensor_shape, tensor_stride, tensor_type, tensor_value); + std::move(tensor_shape), + std::move(tensor_stride), + std::move(tensor_type), + std::move(tensor_value)); } } diff --git a/torch/csrc/utils/python_scalars.h b/torch/csrc/utils/python_scalars.h index a6c91c8bd14a5..66e92ba8fd52d 100644 --- a/torch/csrc/utils/python_scalars.h +++ b/torch/csrc/utils/python_scalars.h @@ -61,11 +61,6 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) { (c10::complex)static_cast>( THPUtils_unpackComplexDouble(obj)); break; - case at::kBComplex32: - *(c10::complex*)data = - (c10::complex)static_cast>( - THPUtils_unpackComplexDouble(obj)); - break; case at::kComplexFloat: *(c10::complex*)data = (c10::complex)THPUtils_unpackComplexDouble(obj); @@ -135,10 +130,6 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) { auto data_ = reinterpret_cast*>(data); return PyComplex_FromDoubles(data_->real(), data_->imag()); } - case at::kBComplex32: { - auto data_ = reinterpret_cast*>(data); - return PyComplex_FromDoubles(data_->real(), data_->imag()); - } case at::kComplexFloat: { auto data_ = reinterpret_cast*>(data); return PyComplex_FromDoubles(data_->real(), data_->imag()); diff --git a/torch/distributed/_composable/replicate_with_fsdp.py b/torch/distributed/_composable/replicate_with_fsdp.py index ba8a54452058f..e0c9125fca52d 100644 --- a/torch/distributed/_composable/replicate_with_fsdp.py +++ b/torch/distributed/_composable/replicate_with_fsdp.py @@ -15,9 +15,11 @@ from torch.distributed.fsdp._fully_shard._fsdp_init import ( _apply_to_module, _get_device_from_mesh, + _get_mesh_info, _get_modules_and_states, _init_default_mesh, _init_param_group, + _validate_mesh as _validate_mesh_common, _validate_module as _validate_module_common, ) from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState, FSDPStateContext @@ -30,6 +32,7 @@ if TYPE_CHECKING: + from torch.distributed.fsdp._fully_shard._fsdp_api import DataParallelMeshDims from torch.distributed.tensor import DeviceMesh @@ -85,6 +88,7 @@ def replicate( mp_policy: MixedPrecisionPolicy = ..., offload_policy: OffloadPolicy = ..., ignored_params: set[nn.Parameter] | None = ..., + dp_mesh_dims: DataParallelMeshDims | None = ..., ) -> ReplicateModule: ... @@ -97,6 +101,7 @@ def replicate( mp_policy: MixedPrecisionPolicy = ..., offload_policy: OffloadPolicy = ..., ignored_params: set[nn.Parameter] | None = ..., + dp_mesh_dims: DataParallelMeshDims | None = ..., ) -> list[ReplicateModule]: ... @@ -108,6 +113,7 @@ def replicate( mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(), offload_policy: OffloadPolicy = OffloadPolicy(), ignored_params: set[nn.Parameter] | None = None, + dp_mesh_dims: DataParallelMeshDims | None = None, ): r"""Replicates a module @@ -122,8 +128,17 @@ def replicate( torch._C._log_api_usage_once("torch.distributed._composable.replicate_with_fsdp") _validate_module(module) mesh = mesh or _init_default_mesh(mesh_dim_names=("replicate",)) - _validate_mesh(mesh) - mesh_info = DDPMeshInfo(mesh, replicate_mesh_dim=0) + if dp_mesh_dims is not None: + _validate_mesh_common(mesh, dp_mesh_dims) + mesh_info = _get_mesh_info(mesh, dp_mesh_dims) + if not isinstance(mesh_info, DDPMeshInfo): + raise ValueError( + "replicate() with dp_mesh_dims requires replicate-only " + "dims (no shard dims). Use fully_shard() for sharding." + ) + else: + _validate_mesh(mesh) + mesh_info = DDPMeshInfo(mesh, replicate_mesh_dim=0) device = _get_device_from_mesh(mesh) # managed_modules (3rd return) and buffers (5th return) are unused: # - managed_modules: FSDP uses this to set Dynamo-specific attributes diff --git a/torch/distributed/checkpoint/_traverse.py b/torch/distributed/checkpoint/_traverse.py index 8b5807fcca8b7..cf0c00a48dda9 100644 --- a/torch/distributed/checkpoint/_traverse.py +++ b/torch/distributed/checkpoint/_traverse.py @@ -3,6 +3,7 @@ # flake8: noqa: F821 from collections.abc import Callable, Collection, Mapping, MutableMapping from typing import cast, TypeVar +from typing_extensions import TypeIs import torch from torch.distributed._shard.sharded_tensor.api import ShardedTensor @@ -20,7 +21,7 @@ __all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"] -def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: +def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> TypeIs[torch.Tensor]: return isinstance(value, torch.Tensor) diff --git a/torch/distributed/checkpoint/api.py b/torch/distributed/checkpoint/api.py index 4aa4854db2358..4d9817f235ca4 100644 --- a/torch/distributed/checkpoint/api.py +++ b/torch/distributed/checkpoint/api.py @@ -8,7 +8,15 @@ def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION: - return (exc, tb.extract_tb(exc.__traceback__)) + summary = tb.extract_tb(exc.__traceback__) + # Python 3.13+ stores bytecode objects in FrameSummary._code, + # which cannot be pickled. Clear them so gather_object succeeds + # and the real exception is reported instead of a misleading + # "cannot pickle code objects" TypeError. + for frame in summary: + if hasattr(frame, "_code"): + object.__setattr__(frame, "_code", None) + return (exc, summary) def _is_wrapped_exception(obj: Any) -> bool: diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 895ba0e17bb31..c84544d2cceaa 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -6,12 +6,14 @@ import warnings from collections.abc import Iterator from itertools import zip_longest -from typing import TYPE_CHECKING +from typing import Any, TYPE_CHECKING import torch +from torch._opaque_base import OpaqueBase from torch.distributed import is_available from torch.distributed._mesh_layout import _MeshLayout from torch.distributed._pycute import IntTuple, is_int, suffix_product +from torch.types import IntLikeType from torch.utils._typing_utils import not_none @@ -148,7 +150,7 @@ def _get_device_handle(device_type: str = "cuda"): """ return getattr(torch, device_type, None) - class DeviceMesh(torch._opaque_base.OpaqueBase): + class DeviceMesh(OpaqueBase): """ DeviceMesh represents a mesh of devices, where layout of devices could be represented as a n-d dimension array, and each value of the n-d dimensional @@ -202,6 +204,7 @@ class DeviceMesh(torch._opaque_base.OpaqueBase): _mesh_dim_names: tuple[str, ...] | None _layout: _MeshLayout _root_mesh: "DeviceMesh | None" = None + _thread_id: int | None # Record flatten mesh name to its flattened mesh in root mesh. _flatten_mapping: dict[str, "DeviceMesh"] # Registry mapping group names to ProcessGroup objects (to avoid C++ lookup) @@ -643,19 +646,21 @@ def __repr__(self) -> str: device_mesh_repr += f", Mesh: {self.mesh.tolist()}" return f"{device_mesh_repr})" + def _hash_key(self) -> tuple[Any, ...]: + """Return the tuple used for hashing. Used by both __hash__ and _stable_hash.""" + return ( + self._flatten_rank_map, + self._layout, + self._device_type, + self._mesh_dim_names, + self._thread_id, + ) + def __hash__(self): # lazily compute hash self._hash = getattr(self, "_hash", None) if not self._hash: - self._hash = hash( - ( - self._flatten_rank_map, - self._layout, - self._device_type, - self._mesh_dim_names, - self._thread_id, - ) - ) + self._hash = hash(self._hash_key()) return self._hash def __eq__(self, other: object) -> bool: @@ -671,6 +676,17 @@ def __eq__(self, other: object) -> bool: and self._thread_id == other._thread_id ) + def _stable_hash(self) -> str: + """ + Return a stable hash for AOT autograd caching. + [See note: Tensor subclass stable hashing for AOT autograd cache] + """ + import hashlib + + return hashlib.blake2b( + repr(self._hash_key()).encode(), digest_size=16 + ).hexdigest() + def __getitem__(self, mesh_dim_names: str | tuple[str, ...]) -> "DeviceMesh": """ Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh. @@ -1220,7 +1236,7 @@ def get_coordinate(self) -> tuple[int, ...] | None: """ return self._coordinate_on_dim - def _sym_get_coordinate(self, index: int) -> int: + def _sym_get_coordinate(self, index: int) -> IntLikeType: import torch.distributed.config as config from torch._guards import detect_fake_mode @@ -1590,6 +1606,7 @@ def _register_distributed_opaque_types(): "rank": MemberType.USE_REAL, "_get_backend_name": MemberType.USE_REAL, "group_name": MemberType.USE_REAL, + "group_desc": MemberType.USE_REAL, "__eq__": MemberType.USE_REAL, }, ) diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py index 1e4219250c39d..7c0d55e37d688 100644 --- a/torch/distributed/fsdp/__init__.py +++ b/torch/distributed/fsdp/__init__.py @@ -1,6 +1,7 @@ from ._flat_param import FlatParameter as FlatParameter from ._fully_shard import ( CPUOffloadPolicy, + DataParallelMeshDims, FSDPModule, fully_shard, MixedPrecisionPolicy, @@ -49,6 +50,7 @@ "StateDictType", # FSDP2 "CPUOffloadPolicy", + "DataParallelMeshDims", "FSDPModule", "fully_shard", "MixedPrecisionPolicy", @@ -60,6 +62,7 @@ # Set namespace for exposed private names CPUOffloadPolicy.__module__ = "torch.distributed.fsdp" +DataParallelMeshDims.__module__ = "torch.distributed.fsdp" FSDPModule.__module__ = "torch.distributed.fsdp" fully_shard.__module__ = "torch.distributed.fsdp" MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp" diff --git a/torch/distributed/fsdp/_fully_shard/__init__.py b/torch/distributed/fsdp/_fully_shard/__init__.py index d4d0b341a3f82..f6aab00fc32b7 100644 --- a/torch/distributed/fsdp/_fully_shard/__init__.py +++ b/torch/distributed/fsdp/_fully_shard/__init__.py @@ -1,4 +1,9 @@ -from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy +from ._fsdp_api import ( + CPUOffloadPolicy, + DataParallelMeshDims, + MixedPrecisionPolicy, + OffloadPolicy, +) from ._fully_shard import ( FSDPModule, fully_shard, @@ -10,6 +15,7 @@ __all__ = [ "CPUOffloadPolicy", + "DataParallelMeshDims", "FSDPModule", "fully_shard", "MixedPrecisionPolicy", diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_api.py b/torch/distributed/fsdp/_fully_shard/_fsdp_api.py index d495bb953cac3..3f03d5707a9ba 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_api.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_api.py @@ -126,6 +126,49 @@ def __call__( ) -> dist.Work | None: ... +@dataclass +class DataParallelMeshDims: + """ + Specifies which dimensions of a full SPMD :class:`DeviceMesh` correspond to + data parallelism when using :func:`fully_shard` whose parameters are already + DTensors on that mesh. + + Attributes: + shard (Optional[Union[str, tuple[str, ...]]]): Mesh dimension name(s) + that FSDP shards parameters on. If a tuple of names, those dims + are flattened into a single shard dimension. At least one of + ``shard`` and ``replicate`` must be set. + replicate (Optional[Union[str, tuple[str, ...]]]): Mesh dimension + name(s) for HSDP or DDP replication. If a tuple of names, those + dims are flattened into a single replicate dimension. + """ + + shard: str | tuple[str, ...] | None = None + replicate: str | tuple[str, ...] | None = None + + def __post_init__(self): + if self.shard is None and self.replicate is None: + raise ValueError( + "At least one of shard or replicate must be set in DataParallelMeshDims" + ) + + @property + def shard_names(self) -> tuple[str, ...]: + if self.shard is None: + return () + if isinstance(self.shard, str): + return (self.shard,) + return tuple(self.shard) + + @property + def replicate_names(self) -> tuple[str, ...]: + if self.replicate is None: + return () + if isinstance(self.replicate, str): + return (self.replicate,) + return tuple(self.replicate) + + @dataclass class OffloadPolicy: """ diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py index 751a7d6b31338..3c8aa312c7187 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py @@ -1,10 +1,11 @@ import math from collections.abc import Callable, Sequence from itertools import chain -from typing import Any, cast, NamedTuple +from typing import Any, cast, Literal, NamedTuple import torch import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem from torch.distributed.device_mesh import _get_device_handle from torch.distributed.distributed_c10d import ReduceOp from torch.distributed.fsdp._fully_shard._fsdp_api import AllGather, ReduceScatter @@ -77,6 +78,36 @@ def allocate( return torch.empty(*size, dtype=dtype, device=device) +class SymmMemAllocMixin: + def __init__( + self, + group: dist.ProcessGroup, + backend: Literal["NCCL"] = "NCCL", + *args: Any, + **kwargs: Any, + ): + self._group = group + symm_mem.set_backend(backend) + # Force initialization of communicator; otherwise, the rendezvous may + # see empty communicator. + # TODO: Remove this, maybe by warning user to perform eager dist init. + # For now, it is okay since it isjust a one-time cost at init. + dist.barrier(group=group) + + def allocate( + self, + size: Sequence[int | torch.SymInt], + *, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + # Leverage MemPool to reuse the symmetric buffer, avoiding allocation + # and rendezvous overhead + mempool = symm_mem.get_mem_pool(device) + with torch.cuda.use_mem_pool(mempool): + return torch.empty(size, dtype=dtype, device=device) + + class DefaultAllGather(DefaultAllocMixin, AllGather): def __call__( self, @@ -112,6 +143,35 @@ def __call__( ) +class SymmMemAllGather(SymmMemAllocMixin, AllGather): + def __init__( + self, + group: dist.ProcessGroup, + backend: Literal["NCCL"] = "NCCL", + ) -> None: + super().__init__(group, backend) + + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + async_op: bool = False, + ) -> dist.Work | None: + # We are doing inplace all-gather, so we need to rendezvous the output tensor only + symm_mem.rendezvous(output_tensor, group=group.group_name) + # Calling regular all-gather would already cause libraries like NCCL to + # use its optimized all-gather implementation for symmetric memory: + # - Copy Engine All-Gather (when zero-CTA policy is enabled) + # - Symmetric Kernel All-Gather (when zero-CTA policy is not enabled) + return dist.all_gather_into_tensor( + output_tensor, + input_tensor, + group=group, + async_op=async_op, + ) + + class DefaultReduceScatter(DefaultAllocMixin, ReduceScatter): def __call__( self, @@ -151,6 +211,35 @@ def __call__( ) +class SymmMemReduceScatter(SymmMemAllocMixin, ReduceScatter): + def __init__( + self, + group: dist.ProcessGroup, + backend: Literal["NCCL"] = "NCCL", + ) -> None: + super().__init__(group, backend) + + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + op: _ReduceOp, + async_op: bool = False, + ) -> dist.Work | None: + symm_mem.rendezvous(input_tensor, group=group.group_name) + symm_mem.rendezvous(output_tensor, group=group.group_name) + # Calling regular reduce-scatter would already cause libraries like NCCL to + # use its optimized reduce-scatter implementation for symmetric memory + return dist.reduce_scatter_tensor( + output=output_tensor, + input=input_tensor, + group=group, + op=op, + async_op=async_op, + ) + + @torch.library.impl(lib, "all_gather_copy_in", "Meta") def all_gather_copy_in_meta( all_gather_inputs: list[torch.Tensor], diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py index 2f76336332e85..f96f634b86f51 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py @@ -2,7 +2,7 @@ import functools import math import traceback -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import auto, Enum from typing import Any @@ -13,6 +13,8 @@ from torch.distributed.tensor import DeviceMesh, DTensor, Shard from torch.distributed.tensor._dtensor_spec import DTensorSpec +from ._fsdp_api import DataParallelMeshDims + def _dynamo_disable(func): """Disable dynamo tracing for FSDP hooks.""" @@ -31,12 +33,19 @@ class DataParallelMeshInfo: mesh: DeviceMesh shard_mesh_dim: int | None = None replicate_mesh_dim: int | None = None + dp_mesh_dims: DataParallelMeshDims | None = None + # The full SPMD mesh (excluding PP dims) that params are distributed on. + # Must include all non-PP SPMD dims (e.g. DP + TP); passing a submesh + # that omits dims like TP will lead to incorrect behavior. + spmd_mesh: DeviceMesh | None = field(default=None, repr=False) + is_spmd_mesh: bool = field(default=False, init=False, repr=False) def __post_init__(self): if self.shard_mesh_dim is None and self.replicate_mesh_dim is None: raise AssertionError( "At least one of shard_mesh_dim and replicate_mesh_dim must not be None" ) + self.is_spmd_mesh = self.dp_mesh_dims is not None @dataclass diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_init.py b/torch/distributed/fsdp/_fully_shard/_fsdp_init.py index 0b1d652422852..42f5293705cda 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_init.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_init.py @@ -13,6 +13,7 @@ from ._fsdp_common import ( _is_composable_with_fsdp, DataParallelMeshInfo, + DDPMeshInfo, FSDPMeshInfo, HSDPMeshInfo, ) @@ -23,7 +24,7 @@ from collections.abc import Callable from typing import Any - from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy + from ._fsdp_api import DataParallelMeshDims, MixedPrecisionPolicy, OffloadPolicy from ._fsdp_common import ShardPlacementFnResult from ._fsdp_state import FSDPState @@ -46,13 +47,35 @@ def _validate_module(module: nn.Module, func_name: str) -> None: ) -def _validate_mesh(mesh: "DeviceMesh") -> None: +def _validate_mesh( + mesh: "DeviceMesh", + dp_mesh_dims: "DataParallelMeshDims | None" = None, +) -> None: """ Validate that the mesh can be used with fully_shard. - Raises ValueError if the mesh is not 1D or 2D. - Raises AssertionError if the mesh is 2D but mesh_dim_names is not specified. + When ``dp_mesh_dims`` is provided, validates that the named dims + exist in the mesh and at least one of shard/replicate is set. + Otherwise raises ValueError if the mesh is not 1D or 2D. """ + if dp_mesh_dims is not None: + if dp_mesh_dims.shard is None and dp_mesh_dims.replicate is None: + raise ValueError( + "At least one of shard or replicate must be set in dp_mesh_dims" + ) + if mesh.mesh_dim_names is None: + raise ValueError( + "mesh must have mesh_dim_names when dp_mesh_dims is provided" + ) + names_to_check: list[str] = list(dp_mesh_dims.shard_names) + names_to_check.extend(dp_mesh_dims.replicate_names) + for name in names_to_check: + if name not in mesh.mesh_dim_names: + raise ValueError( + f"Mesh dim name '{name}' not found in mesh.mesh_dim_names " + f"{mesh.mesh_dim_names}" + ) + return if mesh.ndim not in (1, 2): raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}") if mesh.ndim == 2 and mesh.mesh_dim_names is None: @@ -61,18 +84,71 @@ def _validate_mesh(mesh: "DeviceMesh") -> None: ) -def _get_mesh_info(mesh: "DeviceMesh") -> "FSDPMeshInfo": +def _get_mesh_info( + mesh: "DeviceMesh", + dp_mesh_dims: "DataParallelMeshDims | None" = None, +) -> "DataParallelMeshInfo": """ Get the appropriate mesh info for the given mesh. + When ``dp_mesh_dims`` is provided, extracts the DP submesh from the + full SPMD mesh and returns FSDPMeshInfo, HSDPMeshInfo, or DDPMeshInfo + with ``dp_mesh_dims`` set and ``is_spmd_mesh`` as True. + Returns FSDPMeshInfo for 1D mesh, HSDPMeshInfo for 2D mesh. """ + if dp_mesh_dims is not None: + return _get_mesh_info_from_named_dims(mesh, dp_mesh_dims) if mesh.ndim == 1: return FSDPMeshInfo(mesh, shard_mesh_dim=0) else: return HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0) +def _get_mesh_info_from_named_dims( + mesh: "DeviceMesh", + dp_mesh_dims: "DataParallelMeshDims", +) -> "DataParallelMeshInfo": + shard_names = dp_mesh_dims.shard_names + replicate_names = dp_mesh_dims.replicate_names + + def _get_submesh(names: tuple[str, ...]) -> "DeviceMesh": + if len(names) == 1: + return mesh[names[0]] + # Flatten multi-dim submesh into a single dim so FSDP's internal + # logic (which expects one shard and/or one replicate dim) works + # unchanged. This creates a new 1D DeviceMesh and ProcessGroup. + return mesh[names]._flatten("_".join(names)) + + if len(shard_names) == 0: # DDP + dp_mesh = _get_submesh(replicate_names) + return DDPMeshInfo( + dp_mesh, + replicate_mesh_dim=0, + dp_mesh_dims=dp_mesh_dims, + spmd_mesh=mesh, + ) + if len(replicate_names) == 0: # FSDP + dp_mesh = _get_submesh(shard_names) + return FSDPMeshInfo( + dp_mesh, + shard_mesh_dim=0, + dp_mesh_dims=dp_mesh_dims, + spmd_mesh=mesh, + ) + # HSDP + shard_mesh = _get_submesh(shard_names) + replicate_mesh = _get_submesh(replicate_names) + dp_mesh = DeviceMesh._concatenate([replicate_mesh, shard_mesh]) + return HSDPMeshInfo( + dp_mesh, + shard_mesh_dim=1, + replicate_mesh_dim=0, + dp_mesh_dims=dp_mesh_dims, + spmd_mesh=mesh, + ) + + def _get_post_forward_mesh_info( reshard_after_forward: bool | int, mesh_info: FSDPMeshInfo ) -> FSDPMeshInfo | None: diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py index dca991147ad54..8b1227ee6d92e 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -176,8 +176,9 @@ class FSDPParam: _unsharded_param: nn.Parameter # ND unsharded_accumulated_grad: torch.Tensor | None # ND _sharding_spec: DTensorSpec - # DTensor attributes (only defined for DTensor `param`): - _tp_spec: DTensorSpec + _unsharded_dtensor_spec: ( + DTensorSpec | None + ) # set for DTensor params (SPMD or TP/EP) all_gather_outputs: list[torch.Tensor] # 1D # All-gather extension attributes _extensions_data: ExtensionsData @@ -261,69 +262,7 @@ def _init_sharded_param( # https://github.com/pytorch/pytorch/issues/113045 self.is_dtensor = isinstance(param, DTensor) self._orig_param_uid = _get_orig_param_uid(param) - if self.is_dtensor: - self._tp_spec = cast(DTensor, param)._spec - dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh) - if dp_mesh is None or tp_mesh is None: - raise AssertionError( - "FSDP requires the DP and model parallel TP/EP mesh to be not None but got: \n" - f"DP's mesh: {dp_mesh}\nTP/EP's mesh: {tp_mesh}" - ) - self._spmd_mesh = DeviceMesh._concatenate([dp_mesh, tp_mesh]) - if len(self._tp_spec.placements) > 2: - raise NotImplementedError( - f"FSDP only supports 1D TP/EP or 2D EP+TP, not {self._tp_spec.placements}" - ) - split_factor = self._tp_spec.num_shards_map[shard_dim] - if not (2 <= self._spmd_mesh.ndim <= 4): - raise AssertionError( - "_spmd_mesh.ndim can only be 2 (FSDP+TP/EP), 3 (FSDP+EP+TP, HSDP+TP/EP), " - f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}." - ) - self._spmd_placements: tuple[Placement, ...] - if isinstance(self.mesh_info, FSDPMeshInfo): # FSDP or HSDP - dp_shard_tp_placement = ( - ( - _StridedShard(shard_dim, split_factor=split_factor) - if split_factor > 1 - else fsdp_placement - ), - *self._tp_spec.placements, - ) - else: # DDP - dp_shard_tp_placement = ( - (Replicate()), - *self._tp_spec.placements, - ) - if isinstance(self.mesh_info, HSDPMeshInfo): # HSDP - if self.mesh_info.replicate_mesh_dim != 0: - raise AssertionError( - f"Expected replicate_mesh_dim to be 0, got {self.mesh_info.replicate_mesh_dim}" - ) - self._spmd_placements = (Replicate(),) + dp_shard_tp_placement - else: # FSDP or DDP - self._spmd_placements = dp_shard_tp_placement - - self._sharding_spec = DTensorSpec( - self._spmd_mesh, - self._spmd_placements, - tensor_meta=self._tp_spec.tensor_meta, - ) - param_data = cast(DTensor, param)._local_tensor - else: - self._spmd_mesh = self.mesh_info.mesh - if isinstance(self.mesh_info, HSDPMeshInfo): # HSDP - self._spmd_placements = (Replicate(), fsdp_placement) - elif isinstance(self.mesh_info, FSDPMeshInfo): # FSDP - self._spmd_placements = (fsdp_placement,) - elif isinstance(self.mesh_info, DDPMeshInfo): # DDP - self._spmd_placements = (Replicate(),) - self._sharding_spec = DTensorSpec( - self._spmd_mesh, - self._spmd_placements, - tensor_meta=TensorMeta(param.size(), param.stride(), param.dtype), - ) - param_data = param + param_data = self._init_sharding_spec(param, fsdp_placement, shard_dim) if not param_data.is_contiguous(): raise AssertionError( f"Expected contiguous tensor, got {param_data.shape=} {param_data.stride()=}" @@ -385,6 +324,180 @@ def _init_sharded_param( self._setattr_on_modules(self.sharded_param) self.sharded_state = ShardedState.SHARDED + def _init_sharding_spec( + self, + param: nn.Parameter, + fsdp_placement: Shard, + shard_dim: int, + ) -> torch.Tensor: + """ + Build ``_sharding_spec``, ``_spmd_mesh``, and ``_spmd_placements`` and + return the local tensor data to be sharded. + """ + self._unsharded_dtensor_spec = None + if self.mesh_info.is_spmd_mesh and not self.is_dtensor: + raise ValueError( + "When dp_mesh_dims is provided, all parameters must be " + "DTensors on the full SPMD mesh (e.g. via distribute_module). " + f"Got plain tensor for parameter '{self._module_info.param_name}'." + ) + if self.is_dtensor and self.mesh_info.is_spmd_mesh: + return self._init_sharding_spec_spmd(param, fsdp_placement, shard_dim) + if self.is_dtensor: + return self._init_sharding_spec_tp(param, fsdp_placement, shard_dim) + return self._init_sharding_spec_plain(param, fsdp_placement) + + def _init_sharding_spec_spmd( + self, + param: nn.Parameter, + fsdp_placement: Shard, + shard_dim: int, + ) -> torch.Tensor: + """SPMD path: param is a DTensor on the full SPMD mesh.""" + self._unsharded_dtensor_spec = cast(DTensor, param)._spec + spmd_mesh = self._unsharded_dtensor_spec.mesh + dp_dim_names = self.mesh_info.dp_mesh_dims + if dp_dim_names is None: + raise AssertionError("dp_dim_names must not be None for SPMD mesh") + if spmd_mesh.mesh_dim_names is None: + raise AssertionError("spmd_mesh.mesh_dim_names must not be None") + if ( + self.mesh_info.spmd_mesh is not None + and spmd_mesh is not self.mesh_info.spmd_mesh + ): + raise ValueError( + "Expected param's DTensor mesh to be the same mesh passed " + "to fully_shard, but got different mesh objects" + ) + + dp_shard_indices = [ + spmd_mesh.mesh_dim_names.index(n) for n in dp_dim_names.shard_names + ] + + orig_placements = self._unsharded_dtensor_spec.placements + for idx in dp_shard_indices: + if not isinstance(orig_placements[idx], Replicate): + raise ValueError( + f"Expected Replicate() on DP shard dim " + f"'{spmd_mesh.mesh_dim_names[idx]}' (index {idx}) " + f"but got {orig_placements[idx]}" + ) + dp_replicate_indices = [] + for rep_name in dp_dim_names.replicate_names: + rep_idx = spmd_mesh.mesh_dim_names.index(rep_name) + dp_replicate_indices.append(rep_idx) + if not isinstance(orig_placements[rep_idx], Replicate): + raise ValueError( + f"Expected Replicate() on DP replicate dim " + f"'{spmd_mesh.mesh_dim_names[rep_idx]}' (index {rep_idx}) " + f"but got {orig_placements[rep_idx]}" + ) + + # Cache DP dim indices so _get_grad_inner_tensor can skip + # redistribution on DP dims and let FSDP's reduce-scatter handle them. + self._dp_dim_indices: frozenset[int] = frozenset( + dp_shard_indices + dp_replicate_indices + ) + + new_placements = list(orig_placements) + for dp_idx in dp_shard_indices: + # split_factor = number of non-DP shards on shard_dim from + # mesh dims with higher index (the "right-side" dims that + # _StridedShard needs to interleave with) + sf = 1 + for j in range(dp_idx + 1, spmd_mesh.ndim): + p = orig_placements[j] + if isinstance(p, (Shard, _StridedShard)) and p.dim == shard_dim: + sf *= spmd_mesh.size(j) + new_placements[dp_idx] = ( + _StridedShard(shard_dim, split_factor=sf) if sf > 1 else fsdp_placement + ) + + self._spmd_mesh = spmd_mesh + self._spmd_placements: tuple[Placement, ...] = tuple(new_placements) + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=self._unsharded_dtensor_spec.tensor_meta, + ) + return cast(DTensor, param)._local_tensor + + def _init_sharding_spec_tp( + self, + param: nn.Parameter, + fsdp_placement: Shard, + shard_dim: int, + ) -> torch.Tensor: + """TP/EP path: param is a DTensor, DP mesh is separate from TP mesh.""" + self._unsharded_dtensor_spec = cast(DTensor, param)._spec + dp_mesh, tp_mesh = (self.mesh_info.mesh, self._unsharded_dtensor_spec.mesh) + if dp_mesh is None or tp_mesh is None: + raise AssertionError( + "FSDP requires the DP and model parallel TP/EP mesh to be not None but got: \n" + f"DP's mesh: {dp_mesh}\nTP/EP's mesh: {tp_mesh}" + ) + self._spmd_mesh = DeviceMesh._concatenate([dp_mesh, tp_mesh]) + if len(self._unsharded_dtensor_spec.placements) > 2: + raise NotImplementedError( + f"FSDP only supports 1D TP/EP or 2D EP+TP, not {self._unsharded_dtensor_spec.placements}" + ) + split_factor = self._unsharded_dtensor_spec.num_shards_map[shard_dim] + if not (2 <= self._spmd_mesh.ndim <= 4): + raise AssertionError( + "_spmd_mesh.ndim can only be 2 (FSDP+TP/EP), 3 (FSDP+EP+TP, HSDP+TP/EP), " + f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}." + ) + if isinstance(self.mesh_info, FSDPMeshInfo): + dp_shard_tp_placement = ( + ( + _StridedShard(shard_dim, split_factor=split_factor) + if split_factor > 1 + else fsdp_placement + ), + *self._unsharded_dtensor_spec.placements, + ) + else: # DDP + dp_shard_tp_placement = ( + Replicate(), + *self._unsharded_dtensor_spec.placements, + ) + self._spmd_placements: tuple[Placement, ...] + if isinstance(self.mesh_info, HSDPMeshInfo): + if self.mesh_info.replicate_mesh_dim != 0: + raise AssertionError( + f"Expected replicate_mesh_dim to be 0, got {self.mesh_info.replicate_mesh_dim}" + ) + self._spmd_placements = (Replicate(),) + dp_shard_tp_placement + else: + self._spmd_placements = dp_shard_tp_placement + + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=self._unsharded_dtensor_spec.tensor_meta, + ) + return cast(DTensor, param)._local_tensor + + def _init_sharding_spec_plain( + self, + param: nn.Parameter, + fsdp_placement: Shard, + ) -> torch.Tensor: + """Plain tensor path: param is not a DTensor.""" + self._spmd_mesh = self.mesh_info.mesh + if isinstance(self.mesh_info, HSDPMeshInfo): + self._spmd_placements = (Replicate(), fsdp_placement) + elif isinstance(self.mesh_info, FSDPMeshInfo): + self._spmd_placements = (fsdp_placement,) + elif isinstance(self.mesh_info, DDPMeshInfo): + self._spmd_placements = (Replicate(),) + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=TensorMeta(param.size(), param.stride(), param.dtype), + ) + return param + def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None: mesh_info = self.post_forward_mesh_info if mesh_info is None: @@ -493,8 +606,10 @@ def init_unsharded_param(self): self._contiguous_orig_stride, storage_offset=0, ) - if self.is_dtensor: - unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec) + if self._unsharded_dtensor_spec is not None: + unsharded_param = _from_local_no_grad( + unsharded_param, self._unsharded_dtensor_spec + ) self._unsharded_param = nn.Parameter( unsharded_param, requires_grad=self.sharded_param.requires_grad ) @@ -515,7 +630,7 @@ def to_sharded(self) -> None: def to_sharded_post_forward(self) -> None: if self.is_dtensor: raise NotImplementedError( - "Resharding to smaller mesh with TP is not supported yet" + "Resharding to smaller mesh is not supported for DTensor parameters yet" ) self._assert_in_states(ShardedState.UNSHARDED) if self.post_forward_mesh_info is None: @@ -738,13 +853,32 @@ def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor: grad = grad.wait() if not isinstance(grad, DTensor): raise AssertionError(f"Expected DTensor, got {type(grad)}") - placements = self._tp_spec.placements - if placements != grad.placements: - if len(self._tp_spec.placements) != len(grad.placements): - raise AssertionError( - f"Expected same placement length: {self._tp_spec=} {grad.placements=}" - ) - grad = grad.redistribute(placements=placements) + if self._unsharded_dtensor_spec is None: + raise AssertionError( + "Expected _unsharded_dtensor_spec for DTensor param" + ) + placements = self._unsharded_dtensor_spec.placements + if self.mesh_info.is_spmd_mesh: + # Only redistribute non-DP dims; keep Partial on DP dims + # so FSDP's reduce-scatter handles them directly, avoiding + # a redundant all-reduce on the DP dimensions. + target_placements = tuple( + grad.placements[i] if i in self._dp_dim_indices else placements[i] + for i in range(len(placements)) + ) + if target_placements != grad.placements: + if len(placements) != len(grad.placements): + raise AssertionError( + f"Expected same placement length: {placements=} {grad.placements=}" + ) + grad = grad.redistribute(placements=target_placements) + else: + if placements != grad.placements: + if len(placements) != len(grad.placements): + raise AssertionError( + f"Expected same placement length: {placements=} {grad.placements=}" + ) + grad = grad.redistribute(placements=placements) grad = grad._local_tensor return grad diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py index 3cccd934f9d0d..0b71c63720e05 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -3,7 +3,7 @@ import contextlib import logging -from typing import Any, cast, NamedTuple, TYPE_CHECKING +from typing import Any, cast, Literal, NamedTuple, TYPE_CHECKING import torch import torch.distributed as dist @@ -29,6 +29,8 @@ ProcessGroupAllocAllGather, ProcessGroupAllocReduceScatter, ReduceScatter, + SymmMemAllGather, + SymmMemReduceScatter, ) from ._fsdp_common import ( _dynamo_disable, @@ -275,6 +277,28 @@ def lazy_init(self): self._init_mp_dtypes() self._register_state_dict_hooks() + def set_symm_mem(self, backend: Literal["NCCL"] = "NCCL") -> None: + if not isinstance(self._all_gather_comm, (DefaultAllGather | SymmMemAllGather)): + raise AssertionError( + "cannot call set_symm_mem() " + f"when all gather comm is custom: {self._all_gather_comm.__class__.__name__}" + ) + self._all_gather_comm = SymmMemAllGather( + self._all_gather_process_group, backend + ) + if not isinstance( + self._reduce_scatter_comm, (DefaultReduceScatter | SymmMemReduceScatter) + ): + raise AssertionError( + "cannot call set_symm_mem() " + f"when reduce scatter comm is custom: {self._reduce_scatter_comm.__class__.__name__}" + ) + if self.force_sum_reduction_for_comms: + # As of NCCL 2.29.3, NCCL symmetric reduce-scatter only supports SUM reduction + self._reduce_scatter_comm = SymmMemReduceScatter( + self._reduce_scatter_process_group, backend + ) + def set_allocate_memory_from_process_group(self, enable: bool) -> None: """ Whether to (try to) use the ProcessGroup's allocate_tensor method for diff --git a/torch/distributed/fsdp/_fully_shard/_fully_shard.py b/torch/distributed/fsdp/_fully_shard/_fully_shard.py index f3237bbdea11e..cc48a5f9d90d3 100644 --- a/torch/distributed/fsdp/_fully_shard/_fully_shard.py +++ b/torch/distributed/fsdp/_fully_shard/_fully_shard.py @@ -5,14 +5,20 @@ import functools from contextlib import contextmanager -from typing import Any, cast, NoReturn, overload, TYPE_CHECKING +from typing import Any, cast, Literal, NoReturn, overload, TYPE_CHECKING from typing_extensions import deprecated import torch import torch.nn as nn from torch.distributed._composable import contract -from ._fsdp_api import AllGather, MixedPrecisionPolicy, OffloadPolicy, ReduceScatter +from ._fsdp_api import ( + AllGather, + DataParallelMeshDims, + MixedPrecisionPolicy, + OffloadPolicy, + ReduceScatter, +) from ._fsdp_common import FSDPMeshInfo, ShardPlacementFnResult from ._fsdp_init import ( _apply_to_module, @@ -64,6 +70,7 @@ def fully_shard( mp_policy: MixedPrecisionPolicy = ..., offload_policy: OffloadPolicy = ..., ignored_params: set[nn.Parameter] | None = ..., + dp_mesh_dims: DataParallelMeshDims | None = ..., ) -> FSDPModule: ... @@ -78,6 +85,7 @@ def fully_shard( mp_policy: MixedPrecisionPolicy = ..., offload_policy: OffloadPolicy = ..., ignored_params: set[nn.Parameter] | None = ..., + dp_mesh_dims: DataParallelMeshDims | None = ..., ) -> list[FSDPModule]: ... @@ -96,6 +104,7 @@ def fully_shard( mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(), offload_policy: OffloadPolicy = OffloadPolicy(), ignored_params: set[nn.Parameter] | None = None, + dp_mesh_dims: DataParallelMeshDims | None = None, ): """ Apply fully sharded data parallelism (FSDP) to ``module``, where FSDP @@ -194,6 +203,12 @@ def fully_shard( ignored_params: Optional(Set[nn.Parameter]): The set of parameters to be ignored by FSDP. They will not be sharded, nor moved to the device during init, nor have their gradients reduced in backward. + dp_mesh_dims (Optional[DataParallelMeshDims]): When provided, + ``mesh`` is treated as the full SPMD mesh, and parameters should be + DTensors on this mesh with ``Replicate()`` on all DP dimensions. + The ``shard`` field names which dim(s) FSDP shards on (multiple + dims are flattened). The ``replicate`` field names the HSDP + replication dim(s) (multiple dims are flattened). Returns: FSDPModule: The module with FSDP applied (in-place). @@ -201,16 +216,29 @@ def fully_shard( torch._C._log_api_usage_once("torch.distributed.fsdp.fully_shard") _validate_module(module, "fully_shard") mesh = mesh or _init_default_mesh() - _validate_mesh(mesh) - mesh_info = _get_mesh_info(mesh) + _validate_mesh(mesh, dp_mesh_dims) + mesh_info = _get_mesh_info(mesh, dp_mesh_dims) device = _get_device_from_mesh(mesh) auto_reshard_after_forward = reshard_after_forward is None # If the user does not provide ``reshard_after_forward``, we set it to True. # During lazy_init, we identify which module is the root and override its value to False - post_forward_mesh_info = _get_post_forward_mesh_info( - reshard_after_forward if not auto_reshard_after_forward else True, # type: ignore[arg-type] - mesh_info, - ) + if isinstance(mesh_info, FSDPMeshInfo): + if ( + mesh_info.is_spmd_mesh + and not isinstance(reshard_after_forward, bool) + and isinstance(reshard_after_forward, int) + ): + raise NotImplementedError( + "reshard_after_forward as int is not yet supported with " + "SPMD mesh (dp_mesh_dims)" + ) + post_forward_mesh_info = _get_post_forward_mesh_info( + reshard_after_forward if not auto_reshard_after_forward else True, # type: ignore[arg-type] + mesh_info, + ) + else: + # DDPMeshInfo: no sharding, so no post-forward resharding needed + post_forward_mesh_info = None arg_module, modules, managed_modules, params, buffers = _get_modules_and_states( module, device, ignored_params ) @@ -629,6 +657,40 @@ def set_allocate_memory_from_process_group_for_comm(self, enable: bool) -> None: for fsdp_param_group in state._fsdp_param_groups: fsdp_param_group.set_allocate_memory_from_process_group(enable) + def set_symm_mem_for_comm(self, backend: Literal["NCCL"] = "NCCL") -> None: + """ + Sets the symmetric memory (``symm_mem``) backend for allocating the + staging buffers used in all-gather collectives. This allows NCCL to use + optimized all-gather implementations via symmetric memory. Such + optimization may depend on the topology of the system. For single node, + Copy Engine All-Gather may be used. For multi-node, Symmetric Kernel + All-Gather may be used. + + To enable Copy Engine All-Gather, you need to set the NCCL process group + with the zero-CTA policy. + ```python + opts = dist.ProcessGroupNCCL.Options() + opts.config.cta_policy = dist.ProcessGroupNCCL.NCCL_CTA_POLICY_ZERO + dist.init_process_group(backend="nccl", pg_options=opts, device_id=device) + ``` + Alternatively, you can set the environment variable `NCCL_CTA_POLICY` to 2. + ```bash + export NCCL_CTA_POLICY=2 + ``` + For more details, see [Copy Engine + Collectives](https://docs.pytorch.org/docs/2.11/symmetric_memory.html#copy-engine-collectives). + + This cannot be used together with :meth:`set_custom_all_gather` or + :meth:`set_custom_reduce_scatter`. + + Args: + backend (str): The symmetric memory backend to use. Defaults to + ``"NCCL"``. Currently, only ``"NCCL"`` is supported. + """ + state = self._get_fsdp_state() + for fsdp_param_group in state._fsdp_param_groups: + fsdp_param_group.set_symm_mem(backend) + def _set_unshard_async_op(self, async_op: bool): """ Sets whether to use ``async_op=True`` or ``False`` for the pre-forward diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index 370bccb9a0bcd..2e3d857564a91 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -2,6 +2,7 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import copy +import hashlib import inspect import warnings from collections.abc import Callable, Sequence @@ -389,6 +390,15 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): requires_grad=requires_grad, ) + def _stable_hash_for_caching(self) -> str: + """ + Return a stable hash for AOT autograd caching. + [See note: Tensor subclass stable hashing for AOT autograd cache] + """ + # Combine spec's stable hash with requires_grad + cache_data = self._spec._stable_hash() + str(self.requires_grad) + return hashlib.blake2b(cache_data.encode(), digest_size=16).hexdigest() + def __coerce_tangent_metadata__(self): if not any(isinstance(p, Partial) for p in self.placements): return self diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index ec8292d55de3f..cd2e69309c45f 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -10,6 +10,7 @@ import torch.distributed.tensor._dtensor_spec as dtensor_spec from torch._C._distributed_c10d import _resolve_process_group from torch._logging import warning_once +from torch.distributed._functional_collectives import _are_we_tracing from torch.distributed._local_tensor import ( local_tensor_mode, maybe_run_for_local_tensor, @@ -24,6 +25,8 @@ scatter, Work, ) +from torch.fx.experimental.symbolic_shapes import guard_or_false +from torch.types import IntLikeType logger = logging.getLogger(__name__) @@ -180,21 +183,28 @@ def mesh_broadcast( @maybe_run_for_local_tensor -def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: - from torch.fx.experimental.symbolic_shapes import guard_or_false - - if guard_or_false(pad_size == 0): +def pad_tensor( + tensor: torch.Tensor, pad_dim: int, pad_size: IntLikeType +) -> torch.Tensor: + # During tracing, always emit the pad op even when pad_size=0 so all + # ranks produce identical FX graph structure (SPMD). + # guard_or_false returns False for symbolic sizes, so the pad is always + # emitted during tracing. In eager with concrete pad_size=0, it returns + # True and we skip the no-op pad. + if guard_or_false(pad_size == 0) and not _are_we_tracing(): return tensor pad = [0, 0] * (tensor.ndim - pad_dim) - pad[-1] = pad_size + pad[-1] = pad_size # pyrefly: ignore[unsupported-operation] return torch.nn.functional.pad(tensor, pad) @maybe_run_for_local_tensor -def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: - from torch.fx.experimental.symbolic_shapes import guard_or_false - - if guard_or_false(pad_size == 0): +def unpad_tensor( + tensor: torch.Tensor, pad_dim: int, pad_size: IntLikeType +) -> torch.Tensor: + # During tracing, always emit the narrow op even when pad_size=0 so all + # ranks produce identical FX graph structure (SPMD). + if guard_or_false(pad_size == 0) and not _are_we_tracing(): return tensor return tensor.narrow( pad_dim, diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py index d1f7125bebb25..f40acf5fdcc18 100644 --- a/torch/distributed/tensor/_dtensor_spec.py +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -1,3 +1,4 @@ +import hashlib import itertools import math from collections import defaultdict @@ -415,24 +416,26 @@ def __setattr__(self, attr: str, value: Any) -> None: if not isinstance(value, TensorMeta | TensorMetadata): raise AssertionError(repr(value)) + def _hash_key(self) -> tuple[Any, ...]: + """Return the tuple used for hashing. Used by both __hash__ and _stable_hash.""" + if self.tensor_meta is not None: + return ( + self.mesh, + self.placements, + self.shard_order, + self.tensor_meta.shape, + self.tensor_meta.stride, + self.tensor_meta.dtype, + ) + return (self.mesh, self.placements, self.shard_order) + def _hash_impl(self) -> int: # hashing and equality check for DTensorSpec are used to cache the sharding # propagation results. We only need to consider the mesh, placements, shape # dtype and stride. # Caveat: we need to keep this in mind and sync hash and eq if we add more # fields to them. - if self.tensor_meta is not None: - return hash( - ( - self.mesh, - self.placements, - self.shard_order, - self.tensor_meta.shape, - self.tensor_meta.stride, - self.tensor_meta.dtype, - ) - ) - return hash((self.mesh, self.placements, self.shard_order)) + return hash(self._hash_key()) def __hash__(self) -> int: # We lazily cache the spec to avoid recomputing the hash upon each @@ -443,6 +446,17 @@ def __hash__(self) -> int: self._hash = self._hash_impl() return self._hash + def _stable_hash(self) -> str: + """ + Return a stable hash for AOT autograd caching. + [See note: Tensor subclass stable hashing for AOT autograd cache] + """ + # Get hash key, but replace mesh with its stable hash + key = self._hash_key() + # First element is mesh, replace with its stable hash + stable_key = (self.mesh._stable_hash(),) + key[1:] + return hashlib.blake2b(repr(stable_key).encode(), digest_size=16).hexdigest() + def _check_equals(self, other: object, skip_shapes: bool = False) -> bool: if not ( isinstance(other, DTensorSpec) diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index b1aca77f41323..d413c67169171 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -436,6 +436,9 @@ def convert_to_meta(item): return item.tensor_meta elif isinstance(item, TupleStrategy): return tuple(convert_to_meta(child) for child in item.children) + elif isinstance(item, (list, tuple)): + converted = [convert_to_meta(child) for child in item] + return type(item)(converted) else: return item @@ -450,6 +453,9 @@ def convert_to_meta(item): return item.tensor_meta elif isinstance(item, TupleStrategy): return tuple(convert_to_meta(child) for child in item.children) + elif isinstance(item, (list, tuple)): + converted = [convert_to_meta(child) for child in item] + return type(item)(converted) else: return item diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index 6da1e17db361b..5a00d9265c088 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -13,14 +13,12 @@ OpSchema, OpSpec, OpStrategy, - OutputSharding, PlacementList, RuntimeSchemaInfo, StrategyType, TensorMeta, TupleStrategy, ) -from torch.distributed.tensor._ops._common_rules import pointwise_rule from torch.distributed.tensor._ops.single_dim_strategy import ( _ShardingPlaceholder, register_single_dim_strategy, @@ -32,7 +30,6 @@ is_tensor_partial, normalize_dim, register_op_strategy, - register_prop_rule, shift_shard_dims_after_insert, shift_shard_dims_after_remove, ) @@ -102,9 +99,41 @@ def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType: )(propagate_single_input_strategy) -register_op_strategy( - aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) -)(propagate_single_input_strategy) +def _partial_needs_reduce_for_dtype_cast( + reduce_op: str, + src_dtype: torch.dtype, + target_dtype: torch.dtype | None, +) -> bool: + """Return True when reduce_op does not commute with the dtype cast.""" + if target_dtype is None or src_dtype == target_dtype: + return False + if target_dtype == torch.bool: + return True + if reduce_op in ("max", "min"): + return False + return src_dtype.is_floating_point and not target_dtype.is_floating_point + + +@register_single_dim_strategy( + aten._to_copy.default, + schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]), + allow_unbacked_sharding=True, + allow_uneven_sharding=True, +) +def _to_copy_single_dim_strategy( + op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType +) -> list[list[Placement | _ShardingPlaceholder]]: + input_meta = cast(TensorMeta, args_schema[0]) + src_dtype = input_meta.dtype + target_dtype = cast(torch.dtype | None, kwargs_schema.get("dtype", None)) + + strategies: list[list[Placement | _ShardingPlaceholder]] = [] + for dim in range(len(input_meta.shape)): + strategies.append([_ShardingPlaceholder(dim), _ShardingPlaceholder(dim)]) + for reduce_op in Partial.ALL_REDUCE_OPS: + if not _partial_needs_reduce_for_dtype_cast(reduce_op, src_dtype, target_dtype): + strategies.append([Partial(reduce_op), Partial(reduce_op)]) + return strategies @register_op_strategy( @@ -962,6 +991,87 @@ def index_select_single_dim_strategy( return strategies +@register_single_dim_strategy( + aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True) +) +def index_single_dim_strategy( + op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType +) -> list[list[Placement | _ShardingPlaceholder]]: + values_meta, multi_indices_meta = args_schema + if not isinstance(values_meta, TensorMeta): + raise AssertionError(f"Expected TensorMeta, got {type(values_meta)}") + if not isinstance(multi_indices_meta, (list, tuple)): + raise AssertionError(f"Expected list or tuple, got {type(multi_indices_meta)}") + + indexed_dims = [i for i, idx in enumerate(multi_indices_meta) if idx is not None] + non_indexed_dims = [ + i for i in range(len(values_meta.shape)) if i not in set(indexed_dims) + ] + + index_metas = [idx for idx in multi_indices_meta if idx is not None] + if not all(isinstance(m, TensorMeta) for m in index_metas): + raise AssertionError("Expected all index metas to be TensorMeta") + broadcast_ndim = max(len(m.shape) for m in index_metas) + num_indices = len(indexed_dims) + + # Determine where index output dims are inserted in the result + all_consecutive = all( + indexed_dims[i + 1] - indexed_dims[i] == 1 for i in range(len(indexed_dims) - 1) + ) + insert_dim = indexed_dims[0] if all_consecutive else 0 + + def values_dim_to_output_dim(d: int) -> int: + if d < insert_dim: + return d + return d + broadcast_ndim - sum(1 for idx_dim in indexed_dims if d > idx_dim) + + strategies: list[list[Placement | _ShardingPlaceholder]] = [] + + # Shard values on a non-indexed dim, all indices replicated + for d in non_indexed_dims: + out_dim = values_dim_to_output_dim(d) + rule: list[Placement | _ShardingPlaceholder] = [_ShardingPlaceholder(out_dim)] + rule.append(_ShardingPlaceholder(d)) + rule.extend([Replicate()] * num_indices) + strategies.append(rule) + + # Shard indices on the same broadcast dim. Each index tensor may + # have a different ndim, so we map broadcast dim → tensor dim via + # left-padding. Tensors with size 1 on that dim are replicated + # (broadcast semantics). + for bd in range(broadcast_ndim): + per_tensor: list[tuple[int, int]] = [] # (tensor_dim, size) + for m in index_metas: + offset = broadcast_ndim - len(m.shape) + if bd < offset: + per_tensor.append((-1, 1)) # implicit broadcast + else: + td = bd - offset + per_tensor.append((td, m.shape[td])) + if all(s == 1 for _, s in per_tensor): + continue # all broadcast-only, skip + out_dim = bd + insert_dim + rule: list[Placement | _ShardingPlaceholder] = [_ShardingPlaceholder(out_dim)] + rule.append(Replicate()) + for td, s in per_tensor: + if s > 1: + rule.append(_ShardingPlaceholder(td)) + else: + rule.append(Replicate()) + strategies.append(rule) + + # Partial passthrough from values + for reduce_op in Partial.LINEAR_REDUCE_OPS: + rule: list[Placement | _ShardingPlaceholder] = [ + Partial(reduce_op), + Partial(reduce_op), + ] + rule.extend([Replicate()] * num_indices) + strategies.append(rule) + + return strategies + + @register_single_dim_strategy( [aten.index_put.default, aten._index_put_impl_.default], schema_info=RuntimeSchemaInfo(needs_pytree=True), @@ -1047,135 +1157,6 @@ def index_put_single_dim_strategy( return strategies -@register_prop_rule(aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True)) -def prop_index(op_schema: OpSchema) -> OutputSharding: - """ - Expect replicated on the first input; _mostly_ pointwise on the second input. - - TODO: exception: when the dtype of second input is "bool", then a torch.nonzero needs to be triggered first. - """ - # Current sharding constraints: - # For values: - # 1. We currently require that the dimension of values_spec be replicated or partial - # if they are being indexed on. - # 2. Other dimensions of values_spec can remain sharded if they are so. - # For indices: - # Indices can be either sharded or replicated. All index tensors need to be sharded - # in a compatible way, following the pointwise rule (including resolving Partial - # into either sharded or replicated) - - values_spec, multi_indices_spec = op_schema.args_schema - if not isinstance(values_spec, DTensorSpec): - raise AssertionError(f"Expected DTensorSpec, got {type(values_spec)}") - if not isinstance(multi_indices_spec, list): - raise AssertionError(f"Expected list, got {type(multi_indices_spec)}") - multi_indices_spec = cast(list[DTensorSpec | None], multi_indices_spec) - valid_indices_spec: list[tuple[int, DTensorSpec]] = [ - (i, a) for i, a in enumerate(multi_indices_spec) if a is not None - ] - - # 1. All indices have to be sharded equally. Moreover, indices can be broadcast. - # Here, we piggyback on the pointwise sharding rule for indices. - indices_out = pointwise_rule( - OpSchema( - op=op_schema.op, - args_schema=tuple(v[1] for v in valid_indices_spec), - kwargs_schema={}, - ) - ) - need_reshard_on_indices = indices_out.output_spec is None - - if not need_reshard_on_indices: - # this means that our inputs are already sharded properly and we will use that as our indices_spec - if not isinstance(indices_out.output_spec, DTensorSpec): - raise AssertionError( - f"Expected DTensorSpec, got {type(indices_out.output_spec)}" - ) - indices_spec: DTensorSpec = indices_out.output_spec - else: - if indices_out.redistribute_schema is None: - raise AssertionError("redistribute_schema should not be None") - valid_indices_suggestion = indices_out.redistribute_schema - for i, v in enumerate(valid_indices_suggestion.args_spec): - multi_indices_spec[valid_indices_spec[i][0]] = v - # we'll need to call pointwise_rule again to see what's our ideal indices_spec and then - # use that to compute our ideal values_spec - indices_output_spec = pointwise_rule(valid_indices_suggestion).output_spec - if not isinstance(indices_output_spec, DTensorSpec): - raise AssertionError( - f"Expected DTensorSpec, got {type(indices_output_spec)}" - ) - indices_spec = indices_output_spec - - lookup_dims = {v[0] for v in valid_indices_spec} - - need_reshard_on_values = tuple( - (isinstance(vp, Shard) and (vp.dim in lookup_dims or isinstance(ip, Shard))) - for vp, ip in zip(values_spec.placements, indices_spec.placements) - ) - - if not need_reshard_on_indices and not any(need_reshard_on_values): - value_placements = values_spec.placements - - all_dims_consecutive = all( - b[0] - a[0] == 1 - for b, a in zip(valid_indices_spec[1:], valid_indices_spec[:-1]) - ) - if all_dims_consecutive: - # if all index vectors are consecutives, insert at the dimension of the first index - insert_dim: int = valid_indices_spec[0][0] - else: - # else, insert on the first dimension - insert_dim = 0 - - def place(vp: Placement, ip: Placement) -> Placement: - if isinstance(vp, Shard): - return Shard( - vp.dim - if vp.dim < insert_dim - # accounts for the offset in output dimensions - else vp.dim - + indices_spec.ndim - - sum(1 if vp.dim > v[0] else 0 for v in valid_indices_spec) - ) - if isinstance(ip, Shard): - return Shard(ip.dim + insert_dim) - # Partial or Replicated - return vp - - value_placements = tuple( - place(vp, ip) - for vp, ip in zip(values_spec.placements, indices_spec.placements) - ) - result = OutputSharding( - output_spec=DTensorSpec( - mesh=values_spec.mesh, - placements=value_placements, - ) - ) - return result - else: - result = OutputSharding( - output_spec=None, - redistribute_schema=OpSchema( - op=op_schema.op, - args_schema=( - DTensorSpec( - mesh=values_spec.mesh, - placements=tuple( - Replicate() if need_reshard_on_values[i] else v - for i, v in enumerate(values_spec.placements) - ), - tensor_meta=values_spec.tensor_meta, - ), - multi_indices_spec, - ), - kwargs_schema=op_schema.kwargs_schema, - ), - ) - return result - - @register_op_strategy( [ aten.split.Tensor, diff --git a/torch/distributed/tensor/_ops/single_dim_strategy.py b/torch/distributed/tensor/_ops/single_dim_strategy.py index e2127c29950c4..c64b89a7c156f 100644 --- a/torch/distributed/tensor/_ops/single_dim_strategy.py +++ b/torch/distributed/tensor/_ops/single_dim_strategy.py @@ -173,6 +173,9 @@ def _update_placements(obj: Any): elif isinstance(obj, TupleStrategy): for child in obj.children: _update_placements(child) + elif isinstance(obj, (list, tuple)): + for child in obj: + _update_placements(child) for obj in op_schema.args_schema: _update_placements(obj) @@ -185,17 +188,21 @@ def _update_placements(obj: Any): def _get_num_tensor_inputs(op_schema: OpSchema) -> int: num_inputs = 0 - for obj in op_schema.args_schema: + + def _count(obj: Any) -> int: if isinstance(obj, OpStrategy): - num_inputs += 1 + return 1 elif isinstance(obj, TupleStrategy): - num_inputs += len(obj.children) + return len(obj.children) + elif isinstance(obj, (list, tuple)): + return sum(_count(child) for child in obj) + return 0 + + for obj in op_schema.args_schema: + num_inputs += _count(obj) # Also count tensor kwargs (e.g., "out" for out-variant ops) for obj in op_schema.kwargs_schema.values(): - if isinstance(obj, OpStrategy): - num_inputs += 1 - elif isinstance(obj, TupleStrategy): - num_inputs += len(obj.children) + num_inputs += _count(obj) return num_inputs diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index 4d016e299197d..8455ebbc9af44 100644 --- a/torch/distributed/tensor/_random.py +++ b/torch/distributed/tensor/_random.py @@ -11,6 +11,7 @@ from torch.distributed.device_mesh import _get_device_handle, DeviceMesh from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor.placement_types import _StridedShard, Shard +from torch.types import IntLikeType logger = getLogger(__name__) @@ -391,8 +392,8 @@ def _compute_rng_offsets(self, spec: DTensorSpec) -> tuple[int, int]: return start_offset_incr, end_offset_incr def _calc_shard_linear_idx( - self, shard_coord: list[int], shard_size: list[int] - ) -> int: + self, shard_coord: Sequence[IntLikeType], shard_size: Sequence[IntLikeType] + ) -> IntLikeType: return _calc_shard_linear_idx(shard_coord, shard_size) @@ -411,8 +412,8 @@ def _calc_first_shard_size(spec: DTensorSpec) -> list[int]: def _calc_shard_info( - mesh_coordinate: Sequence[int], spec: DTensorSpec -) -> tuple[list[int], list[int]]: + mesh_coordinate: Sequence[IntLikeType], spec: DTensorSpec +) -> tuple[list[IntLikeType], list[IntLikeType]]: mesh = spec.mesh # note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP # case. Replace the custom logic with dim_map once we support it. @@ -436,10 +437,12 @@ def _calc_shard_info( raise AssertionError mesh_size = mesh.shape shard_idx_by_dim = [] - total_num_shards_by_dim = [] # total number of shards on each tensor dim + total_num_shards_by_dim: list[ + IntLikeType + ] = [] # total number of shards on each tensor dim for mesh_dim in dim_map: - shard_idx = 0 - total_num_shards = 1 + shard_idx: IntLikeType = 0 + total_num_shards: IntLikeType = 1 # the tensor dim is sharded on more than 1 mesh dim if isinstance(mesh_dim, list): rank_coord = [mesh_coordinate[d] for d in mesh_dim] @@ -454,10 +457,12 @@ def _calc_shard_info( return shard_idx_by_dim, total_num_shards_by_dim -def _calc_shard_linear_idx(shard_coord: list[int], shard_size: list[int]) -> int: +def _calc_shard_linear_idx( + shard_coord: Sequence[IntLikeType], shard_size: Sequence[IntLikeType] +) -> IntLikeType: # compute shard linear index - shard_linear_idx = 0 - shard_coord_stride = 1 + shard_linear_idx: IntLikeType = 0 + shard_coord_stride: IntLikeType = 1 for idx, size in zip(reversed(shard_coord), reversed(shard_size)): shard_linear_idx += idx * shard_coord_stride shard_coord_stride *= size diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py index 1c7cb2a485c78..25417cbe0356e 100644 --- a/torch/distributed/tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -31,6 +31,7 @@ Replicate, Shard, ) +from torch.types import IntLikeType from torch.utils._debug_mode import get_active_debug_mode @@ -144,7 +145,7 @@ class _TransformInfo: mesh_dim: int src_dst_placements: tuple[Placement, Placement] # logical_shape on this mesh dimension - logical_shape: list[int] + logical_shape: Sequence[IntLikeType] def __post_init__(self): if self.mesh_dim < 0: @@ -1176,8 +1177,8 @@ def get_logical_shape( src_state: "DTensorRedistributePlanner.DistState", mesh_dim: int, full_tensor_shape: tuple[int, ...], - ) -> list[int]: - new_logical_shape = list(full_tensor_shape) + ) -> list[IntLikeType]: + new_logical_shape: list[IntLikeType] = list(full_tensor_shape) for entry in src_state.tensor_dim_to_mesh_dim: tensor_dim = entry.tensor_dim mesh_dims = entry.mesh_dims diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index de0b7ba08b717..a85b19130f69f 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -586,17 +586,23 @@ def _wrap_with_op_strategy(self, op_schema: OpSchema) -> OpSchema: def spec_to_strategy(spec: object) -> object: if isinstance(spec, DTensorSpec): return OpStrategy([OpSpec(spec)]) - elif ( - isinstance(spec, (list, tuple)) - and len(spec) > 0 - and isinstance(spec[0], DTensorSpec) - ): - # tensor list create tuple strategy - tuple_strategy = [spec_to_strategy(s) for s in spec] - tuple_strategy = cast(Sequence[StrategyType], tuple_strategy) - return TupleStrategy( - tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy - ) + elif isinstance(spec, (list, tuple)) and len(spec) > 0: + if all(isinstance(s, DTensorSpec) for s in spec): + # tensor list create tuple strategy + tuple_strategy = [spec_to_strategy(s) for s in spec] + tuple_strategy = cast(Sequence[StrategyType], tuple_strategy) + return TupleStrategy( + tuple(tuple_strategy) + if isinstance(spec, tuple) + else tuple_strategy + ) + elif any(isinstance(s, DTensorSpec) for s in spec): + # mixed list (e.g. [DTensorSpec, None, DTensorSpec]) for + # ops like aten.index.Tensor; keep as list so pytree + # flattening can extract OpStrategy items + return [spec_to_strategy(s) for s in spec] + else: + return spec else: return spec diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index 1e717569986e9..53b96e9051dd4 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -2,6 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import functools +from collections.abc import Sequence from dataclasses import dataclass, field from typing import cast, TypeVar @@ -21,6 +22,7 @@ unpad_tensor, ) from torch.distributed.tensor._ops._mask_buffer import MaskBuffer +from torch.types import IntLikeType __all__ = ["Placement", "Shard", "Replicate", "Partial"] @@ -211,7 +213,7 @@ def _custom_chunk( @staticmethod @maybe_run_for_local_tensor def local_shard_size_and_offset( - curr_local_size: int, + curr_local_size: IntLikeType, num_chunks: int, rank: _RankTypeT, ) -> tuple[_RankTypeT, _RankTypeT]: @@ -392,7 +394,7 @@ def _reduce_shard_tensor( def _maybe_pad_tensor( self, local_tensor: torch.Tensor, - logical_dim_size: int, + logical_dim_size: IntLikeType, num_chunks: int, ) -> torch.Tensor: from torch.fx.experimental.symbolic_shapes import guard_or_true @@ -414,7 +416,7 @@ def _maybe_pad_tensor( def _maybe_unpad_tensor( self, local_tensor: torch.Tensor, - logical_dim_size: int, + logical_dim_size: IntLikeType, num_chunks: int, ) -> torch.Tensor: from torch.fx.experimental.symbolic_shapes import guard_or_true @@ -434,7 +436,7 @@ def _to_replicate_tensor( local_tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - current_logical_shape: list[int], + current_logical_shape: Sequence[IntLikeType], ) -> torch.Tensor: """ This function all_gather all shards and return a tensor that @@ -462,7 +464,7 @@ def _replicate_to_shard( local_tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - shard_index: int, + shard_index: IntLikeType, ) -> torch.Tensor: """ transform from replicated tensor to a sharded tensor on @@ -489,11 +491,11 @@ def _get_shard_pad_size( @staticmethod def _compute_padding_info( - current_logical_shape: list[int], + current_logical_shape: Sequence[IntLikeType], num_chunks: int, old_shard_dim: int, new_shard_dim: int, - ) -> tuple[bool, int, int, bool, int, int]: + ) -> tuple[bool, IntLikeType, int, bool, IntLikeType, int]: from torch.fx.experimental.symbolic_shapes import guard_or_true results = [] @@ -508,7 +510,7 @@ def _compute_padding_info( @staticmethod @maybe_run_for_local_tensor def _pad_for_new_shard_dim( - current_logical_shape: list[int], + current_logical_shape: Sequence[IntLikeType], local_tensor: torch.Tensor, num_chunks: int, old_shard_dim: int, @@ -543,7 +545,7 @@ def _pad_for_new_shard_dim( @staticmethod @maybe_run_for_local_tensor def _unpad_for_new_shard_dim( - current_logical_shape: list[int], + current_logical_shape: Sequence[IntLikeType], local_tensor: torch.Tensor, num_chunks: int, old_shard_dim: int, @@ -582,7 +584,7 @@ def _to_new_shard_dim( local_tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - current_logical_shape: list[int], + current_logical_shape: Sequence[IntLikeType], new_shard_dim: int, ) -> torch.Tensor: """ @@ -857,7 +859,7 @@ def _select_split_tensor( self, tensor: torch.Tensor, num_chunks: int, - index: int, + index: IntLikeType, *, with_padding: bool = True, contiguous: bool = True, @@ -891,7 +893,7 @@ def _to_replicate_tensor( local_tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - current_logical_shape: list[int], + current_logical_shape: Sequence[IntLikeType], ) -> torch.Tensor: """ Replay the replicate-to-shard process to understand how to stitch shards back. @@ -986,6 +988,7 @@ def _to_replicate_tensor( # indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed # so that we can reuse self._split_tensor which splits on self.dim shape = [1] * self.dim + [logical_dim_size] + # pyrefly: ignore [no-matching-overload] indices_tensor = torch.arange( logical_dim_size, device=local_tensor.device ).view(shape) @@ -1050,7 +1053,7 @@ def _replicate_to_strided_shard( local_tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - shard_index: int, + shard_index: IntLikeType, ) -> torch.Tensor: """ Transform from replicated tensor to a strided-sharded tensor on the current rank. @@ -1097,7 +1100,7 @@ def _local_shard_size_and_offset( @maybe_run_for_local_tensor def local_shard_size_and_offset( self, - curr_local_size: int, + curr_local_size: IntLikeType, num_chunks: int, rank: RankType, return_first_offset: bool = True, @@ -1127,6 +1130,7 @@ def local_shard_size_and_offset( # indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed # so that we can reuse self._split_tensor which splits on self.dim shape = [1] * self.dim + [curr_local_size] + # pyrefly: ignore [no-matching-overload] indices_tensor = torch.arange( curr_local_size, ).view(shape) @@ -1384,7 +1388,9 @@ def __init__( @staticmethod @maybe_run_for_local_tensor def _mask_tensor( - tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int + tensor: torch.Tensor, + local_offset_on_dim: IntLikeType, + local_shard_size: IntLikeType, ) -> tuple[torch.Tensor, torch.Tensor]: # Build the input mask and save it for the current partial placement # this is so that the output of embedding op can reuse the same partial @@ -1393,8 +1399,10 @@ def _mask_tensor( tensor >= local_offset_on_dim + local_shard_size ) # mask the input tensor + # pyrefly: ignore [unsupported-operation] masked_tensor = tensor.clone() - local_offset_on_dim masked_tensor[mask] = 0 + # pyrefly: ignore [bad-return] return mask, masked_tensor def _partition_value( diff --git a/torch/export/_leakage_detection_utils.py b/torch/export/_leakage_detection_utils.py index fe211e1dc079c..722a756431ea1 100644 --- a/torch/export/_leakage_detection_utils.py +++ b/torch/export/_leakage_detection_utils.py @@ -2,8 +2,9 @@ import types import typing import weakref +from typing_extensions import TypeIs -import torch +from torch.fx.experimental.symbolic_shapes import TrackedFake """ @@ -37,8 +38,8 @@ def _is_globals_or_locals(obj: typing.Any) -> bool: return obj is globals() or obj is locals() -def _is_tracked_fake(obj: typing.Any) -> bool: - return isinstance(obj, torch.fx.experimental.symbolic_shapes.TrackedFake) +def _is_tracked_fake(obj: typing.Any) -> TypeIs[TrackedFake]: + return isinstance(obj, TrackedFake) def _is_gm_meta_like_dict(d: dict, o: typing.Any) -> bool: diff --git a/torch/export/pt2_archive/_package.py b/torch/export/pt2_archive/_package.py index 8d7cb8972cec6..f8f413f9cfe60 100644 --- a/torch/export/pt2_archive/_package.py +++ b/torch/export/pt2_archive/_package.py @@ -7,6 +7,7 @@ import zipfile from dataclasses import dataclass from typing import Any, IO, TYPE_CHECKING, TypeAlias +from typing_extensions import TypeIs import torch import torch.utils._pytree as pytree @@ -331,7 +332,7 @@ def _package_aoti_files( logger.debug(weights_config) -def _is_fake_tensor(t: torch.Tensor) -> bool: +def _is_fake_tensor(t: torch.Tensor) -> TypeIs[FakeTensor]: return isinstance(t, FakeTensor) diff --git a/torch/fx/experimental/_size_hinting.py b/torch/fx/experimental/_size_hinting.py index 24090b35b5b71..80791e4dccf64 100644 --- a/torch/fx/experimental/_size_hinting.py +++ b/torch/fx/experimental/_size_hinting.py @@ -418,7 +418,16 @@ def _optimization_hint_base( sym_fallback = min(sym_fallback, int(vr.upper)) size_dict[s] = sym_fallback - final_result = expr.subs(size_dict) + try: + final_result = expr.subs(size_dict) + except ZeroDivisionError: + # Expressions like ModularIndexing(x, u1, 4) crash during subs() + # when u1 is substituted with 0, because sympy eagerly evaluates + # (x // 0) % 4. This can happen when an unbacked symbol with + # var_to_range lower=0 is used as a divisor (e.g. from + # _dynamic_reshape_indexer) and the fallback also maps to 0. + # Return fallback in that case. + return fallback if fallback is not None else 0 final_result = _maybe_realize_expr(final_result, fallback) if final_result is None: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index ba00b18c0774c..58dd5c93c15e9 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -2318,6 +2318,7 @@ class StatefulSymbolicContext(StatelessSymbolicContext): shape_env_to_source_to_symbol_cache: dict[int, dict[str, sympy.Expr]] = field( default_factory=dict ) + excluded_sizes: tuple[int | None, ...] | None = None @dataclass(frozen=True, slots=True) @@ -3936,6 +3937,11 @@ def _init( self.unbacked_renamings: dict[sympy.Symbol, sympy.Symbol] = {} # Set holds a % b expressions that evaluate to 0. self.divisible: set[sympy.Expr] = set() + # Exclusion constraints from automatic_dynamic transitions. + # Each (symbol, excluded_value) pair represents one dim/scalar that + # transitioned static → dynamic. All pairs are combined into a single + # Or(Ne(...), ...) guard in produce_guards_verbose. + self.exclusion_constraints: list[tuple[sympy.Symbol, int]] = [] # Set that holds "size-like" symbols. When we perform # "size-oblivious" tests, these can be assumed to be >= 2. self.size_like: set[sympy.Symbol] = set() @@ -4825,6 +4831,27 @@ def _create_symbolic_sizes_strides_storage_offset( size: list[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple( ex_size, source, symbolic_context, hint_overrides=hint_overrides ) + # Record tensor exclusion constraints for stable graph selection. + # The ndim check guards against stale excluded_sizes from graph + # breaks where the resumed tensor may have different dimensionality. + # Skip dims with hint overrides: the overridden hint in + # backed_var_to_val would mismatch the excluded value, causing the + # not-all check in produce_guards_verbose to emit a guard that + # immediately fails. + excluded_sizes = getattr(symbolic_context, "excluded_sizes", None) + if ( + excluded_sizes + and len(excluded_sizes) == dim + and any(v is not None for v in excluded_sizes) + ): + for i in range(dim): + ev = excluded_sizes[i] + if ( + ev is not None + and isinstance(size[i], sympy.Symbol) + and i not in (hint_overrides or {}) + ): + self._record_exclusion_constraint(size[i], ev) stride = self._compute_symbolic_stride( source, size, @@ -5029,17 +5056,28 @@ def create_symfloatnode( out = SymFloat(SymNode(sym, self, float, hint, fx_node=fx_node)) return out + @record_shapeenv_event() + def _record_exclusion_constraint(self, sym: sympy.Symbol, val: int) -> None: + self.exclusion_constraints.append((sym, val)) + @record_shapeenv_event() def create_unspecified_symint_and_symbol( - self, value: int, source: Source, dynamic_dim: DimDynamic + self, + value: int, + source: Source, + dynamic_dim: DimDynamic, + excluded_value: int | None = None, ) -> IntLikeType: """Create a SymInt wrapping a new unspecified symbol""" + sym = self.create_unspecified_symbol( + value, + source=source, + dynamic_dim=dynamic_dim, + ) + if excluded_value is not None: + self._record_exclusion_constraint(sym, excluded_value) return self.create_symintnode( - self.create_unspecified_symbol( - value, - source=source, - dynamic_dim=dynamic_dim, - ), + sym, hint=value, source=source, ) @@ -6321,6 +6359,93 @@ def issue_guard(guard: ShapeGuard) -> None: else: raise NotImplementedError(f"Unimplemented for lang: {lang}") + # Exclusion guard for stable graph selection with automatic dynamic. + # + # When automatic_dynamic promotes a static dim to dynamic, the new + # (more general) graph is inserted *before* the old (specialized) graph + # in the guard cache. Without an exclusion guard, inputs that exactly + # match the old graph's static sizes would be captured by the new + # dynamic graph instead, violating the invariant "once an input is + # served by graph X it is always served by graph X". This condition + # is true iff there is no branching on dynamic shapes. + # + # Soundness argument (cache-flip / LIFO order): + # Graph_new sits before Graph_old in the cache. Graph_old accepts + # only inputs whose sizes match its static constraints exactly. + # Graph_new must therefore reject exactly that set of inputs so they + # fall through to Graph_old. The excluded values are the static + # sizes from Graph_old, so the guard + # Or(Ne(s0, v0), Ne(s1, v1), ...) + # passes iff at least one dim differs from the old sizes — i.e. the + # input does NOT fully match Graph_old. Conversely, when every dim + # matches the old sizes the guard fails and the input falls through + # to Graph_old, which is guaranteed to accept it. + # + # Theorem: For graphs G0, ..., Gn compiled via progressive dynamism + # (one dim per step), each input is accepted by at most one graph. + # + # Setup: G0 is all-static with shape S. Gk is created by making + # dim d_k dynamic, with exclusion guard d_k != S[d_k]. + # + # Proof by induction on n: + # + # Base case (n=0): Only G0, all-static. Trivially unique. + # + # Inductive step: Assume the property holds for G0, ..., G_{n-1}. + # We add Gn with newly-dynamic dim d_n and exclusion d_n != S[d_n]. + # + # For any input X that passes Gn's shape guards, exactly one of: + # + # Case A — exclusion passes (X[d_n] != S[d_n]): + # Dim d_n is static in all G0, ..., G_{n-1} with value S[d_n], + # so X fails all prior graphs on that dim. Only Gn accepts X. + # + # Case B — exclusion rejects (X[d_n] == S[d_n]): + # X matches Gn's shape guards on all other dims, and matches + # the static value for d_n. So X satisfies G_{n-1}'s shape + # guards. By the inductive hypothesis, exactly one of + # G0, ..., G_{n-1} accepts X. Gn rejects X. + # + # Corollary: Evaluation order does not affect correctness. + # + # All exclusion pairs across all tensors and scalars are flattened + # into a single list — each pair is just (symbol, excluded_int), + # and the multi-tensor case is the same logic as multi-dim within + # one tensor. The combined Or rejects only when ALL pairs match + # simultaneously, which is the exact condition for Graph_old to + # accept. If the current concrete values already match every + # excluded value the guard is skipped (it would fail on creation). + import torch._dynamo.config as dynamo_config + + if ( + dynamo_config.automatic_dynamic_exclusion_guard + and not dynamo_config.enable_compiler_collectives + and self.exclusion_constraints + ): + all_pairs = [ + (sym, val) + for sym, val in self.exclusion_constraints + if symbol_to_source.get(sym) + ] + if all_pairs and not all( + self.backed_var_to_val.get(sym) == val for sym, val in all_pairs + ): + if len(all_pairs) == 1: + excl_expr = sympy.Ne( + all_pairs[0][0], all_pairs[0][1], evaluate=False + ) + else: + excl_expr = sympy.Or( + *[sympy.Ne(sym, val, evaluate=False) for sym, val in all_pairs] + ) + for exprs, printer, lang in zip(all_exprs, printers, langs): + guard_expr = printer.doprint(excl_expr) + if lang == "verbose_python": + guard_expr = ( + f"{guard_expr} # exclusion guard for automatic dynamic" + ) + exprs.append(guard_expr) + if constraint_violations: warn_msgs: list[str] = [] error_msgs: list[str] = [] diff --git a/torch/headeronly/core/ScalarType.h b/torch/headeronly/core/ScalarType.h index 280e2d984dcb8..ce43ce6866cd9 100644 --- a/torch/headeronly/core/ScalarType.h +++ b/torch/headeronly/core/ScalarType.h @@ -88,7 +88,6 @@ struct dummy_int1_7_t {}; _(float, Float) \ _(double, Double) \ _(c10::complex, ComplexHalf) \ - _(c10::complex, BComplex32) \ _(c10::complex, ComplexFloat) \ _(c10::complex, ComplexDouble) \ _(bool, Bool) \ @@ -147,8 +146,7 @@ struct dummy_int1_7_t {}; _(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \ _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \ _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \ - _(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */ \ - _(c10::complex, BComplex32) /* 46 */ + _(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */ // NB: despite its generic sounding name, the macros that don't take _AND // are mostly only used by tensorexpr diff --git a/torch/headeronly/util/complex.h b/torch/headeronly/util/complex.h index 414307669524f..733a22d5dbb7a 100644 --- a/torch/headeronly/util/complex.h +++ b/torch/headeronly/util/complex.h @@ -3,7 +3,6 @@ #include #include -#include #include #if defined(__CUDACC__) || defined(__HIPCC__) @@ -589,60 +588,6 @@ struct alignas(4) complex { } }; -template <> -struct alignas(4) complex { - BFloat16 real_; - BFloat16 imag_; - - // Constructors - complex() = default; - // BFloat16 constructor is not constexpr so the following constructor can't - // be constexpr - C10_HOST_DEVICE explicit inline complex( - const BFloat16& real, - const BFloat16& imag) - : real_(real), imag_(imag) {} - C10_HOST_DEVICE inline complex(const c10::complex& value) - : real_(value.real()), imag_(value.imag()) {} - - // Conversion operator - inline C10_HOST_DEVICE operator c10::complex() const { - return {real_, imag_}; - } - - constexpr C10_HOST_DEVICE BFloat16 real() const { - return real_; - } - constexpr C10_HOST_DEVICE BFloat16 imag() const { - return imag_; - } - - C10_HOST_DEVICE complex& operator+=( - const complex& other) { - real_ = static_cast(real_) + static_cast(other.real_); - imag_ = static_cast(imag_) + static_cast(other.imag_); - return *this; - } - - C10_HOST_DEVICE complex& operator-=( - const complex& other) { - real_ = static_cast(real_) - static_cast(other.real_); - imag_ = static_cast(imag_) - static_cast(other.imag_); - return *this; - } - - C10_HOST_DEVICE complex& operator*=( - const complex& other) { - auto a = static_cast(real_); - auto b = static_cast(imag_); - auto c = static_cast(other.real()); - auto d = static_cast(other.imag()); - real_ = a * c - b * d; - imag_ = a * d + b * c; - return *this; - } -}; - } // namespace c10 HIDDEN_NAMESPACE_BEGIN(torch, headeronly) diff --git a/torch/nn/attention/_fa4.py b/torch/nn/attention/_fa4.py index f0ea99a463532..c2786d9dceefe 100644 --- a/torch/nn/attention/_fa4.py +++ b/torch/nn/attention/_fa4.py @@ -67,6 +67,11 @@ def _fa4_import_module(module_path: str) -> ModuleType: def _fa4_register_kernels() -> Library: lib = Library("aten", "IMPL", "CUDA") # noqa: TOR901 lib.impl("_flash_attention_forward", _fa4_flash_attention_forward_impl, "CUDA") + lib.impl( + "_flash_attention_forward_no_dropout_inplace", + _fa4_flash_attention_forward_no_dropout_inplace_impl, + "CUDA", + ) lib.impl("_flash_attention_backward", _fa4_flash_attention_backward_impl, "CUDA") lib.impl( "_scaled_dot_product_flash_attention", @@ -116,6 +121,8 @@ def _fa4_forward_support_error( alibi_slopes: torch.Tensor | None, seqused_k: torch.Tensor | None, cum_seq_q: torch.Tensor | None, + block_table: torch.Tensor | None = None, + num_splits: int | None = None, ) -> str | None: if dropout_p != 0.0: return "dropout_p must be 0" @@ -128,6 +135,11 @@ def _fa4_forward_support_error( return "seqused_k must be int32" if not seqused_k.is_cuda: return "seqused_k must be CUDA" + major = _get_device_major(query.device) + if block_table is not None and major != 10: + return f"paged KV (block_table) not supported on SM {major}0" + if num_splits is not None and num_splits > 1 and major != 10: + return f"SplitKV (num_splits > 1) not supported on SM {major}0" error = _fa4_common_support_error( query, (query, key, value), @@ -149,13 +161,9 @@ def _fa4_backward_support_error( logsumexp: torch.Tensor, dropout_p: float, cum_seq_q: torch.Tensor | None, - window_size_left: int | None, - window_size_right: int | None, ) -> str | None: if dropout_p != 0.0: return "dropout_p must be 0" - if window_size_left is not None or window_size_right is not None: - return "windowed attention not supported" error = _fa4_common_support_error( query, (grad_out, query, key, value, out, logsumexp), @@ -167,6 +175,11 @@ def _fa4_backward_support_error( return None +def _aten_to_fa4_window_size(val: int | None) -> int | None: + """need to convert -1 to None for FA4""" + return None if val == -1 else val + + Ts = TypeVarTuple("Ts") @@ -180,12 +193,16 @@ def _fa4_run_forward( value: torch.Tensor, cu_seq_q: torch.Tensor | None, cu_seq_k: torch.Tensor | None, + max_q: int | None, + max_k: int | None, scale: float | None, is_causal: bool, window_size_left: int | None, window_size_right: int | None, seqused_k: torch.Tensor | None, out: torch.Tensor | None = None, + block_table: torch.Tensor | None = None, + num_splits: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if _FA4_MODULE_PATH is None: raise RuntimeError("FA4 not registered") @@ -194,15 +211,18 @@ def _fa4_run_forward( kwargs: dict[str, Any] = { "softmax_scale": scale, "causal": is_causal, - "window_size_left": window_size_left, - "window_size_right": window_size_right, + "window_size_left": _aten_to_fa4_window_size(window_size_left), + "window_size_right": _aten_to_fa4_window_size(window_size_right), "return_lse": True, "cu_seqlens_q": cu_seq_q, "cu_seqlens_k": cu_seq_k, + "max_seqlen_q": max_q, + "max_seqlen_k": max_k, "seqused_k": seqused_k.contiguous() if seqused_k is not None else None, + "page_table": block_table, + "num_splits": num_splits or 1, + "out": out, } - if out is not None: - kwargs["out"] = out out, lse = module._flash_attn_fwd(query, key, value, **kwargs) return out, lse.contiguous() @@ -218,6 +238,8 @@ def _fa4_run_backward( cu_seq_k: torch.Tensor | None, scale: float | None, is_causal: bool, + window_size_left: int | None, + window_size_right: int | None, deterministic: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if _FA4_MODULE_PATH is None: @@ -232,6 +254,8 @@ def _fa4_run_backward( logsumexp.contiguous(), softmax_scale=scale, causal=is_causal, + window_size_left=_aten_to_fa4_window_size(window_size_left), + window_size_right=_aten_to_fa4_window_size(window_size_right), cu_seqlens_q=cu_seq_q, cu_seqlens_k=cu_seq_k, deterministic=deterministic, @@ -257,6 +281,9 @@ def _fa4_flash_attention_forward_impl( seqused_k: torch.Tensor | None = None, alibi_slopes: torch.Tensor | None = None, out: torch.Tensor | None = None, + block_table: torch.Tensor | None = None, + compute_auxiliary: bool = True, + num_splits: int | None = None, ): error = _fa4_forward_support_error( query, @@ -267,6 +294,8 @@ def _fa4_flash_attention_forward_impl( alibi_slopes, seqused_k, cum_seq_q, + block_table, + num_splits, ) if error is not None: raise RuntimeError(f"FA4 flash_attention forward unsupported: {error}") @@ -276,19 +305,73 @@ def _fa4_flash_attention_forward_impl( value, cum_seq_q, cum_seq_k, + max_q, + max_k, scale, is_causal, window_size_left, window_size_right, seqused_k, out, + block_table, + num_splits, ) - rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device) - philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device) - debug_mask = torch.empty(0, dtype=query.dtype, device=query.device) + if compute_auxiliary: + rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device) + philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device) + debug_mask = torch.empty(0, dtype=query.dtype, device=query.device) + else: + rng_state = None + philox_offset = None + debug_mask = None return out, lse, rng_state, philox_offset, debug_mask +def _fa4_flash_attention_forward_no_dropout_inplace_impl( + out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cum_seq_q: torch.Tensor | None, + cum_seq_k: torch.Tensor | None, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + return_debug_mask: bool, + *, + scale: float | None = None, + window_size_left: int | None = None, + window_size_right: int | None = None, + seqused_k: torch.Tensor | None = None, + alibi_slopes: torch.Tensor | None = None, + block_table: torch.Tensor | None = None, + num_splits: int | None = None, +): + _, lse, _, _, _ = _fa4_flash_attention_forward_impl( + query, + key, + value, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p, + is_causal, + return_debug_mask, + scale=scale, + window_size_left=window_size_left, + window_size_right=window_size_right, + seqused_k=seqused_k, + alibi_slopes=alibi_slopes, + out=out, + block_table=block_table, + compute_auxiliary=False, + num_splits=num_splits, + ) + return lse + + def _fa4_flash_attention_backward_impl( grad_out: torch.Tensor, query: torch.Tensor, @@ -318,8 +401,6 @@ def _fa4_flash_attention_backward_impl( logsumexp, dropout_p, cum_seq_q, - window_size_left, - window_size_right, ) if error is not None: raise RuntimeError(f"FA4 flash_attention backward unsupported: {error}") @@ -335,6 +416,8 @@ def _fa4_flash_attention_backward_impl( cum_seq_k, scale, is_causal, + window_size_left, + window_size_right, deterministic, ) return dq, dk, dv @@ -428,8 +511,6 @@ def _fa4_scaled_dot_product_flash_attention_backward_impl( logsumexp, dropout_p, None, - None, - None, ) if error is not None: raise RuntimeError(f"FA4 SDPA backward unsupported: {error}") diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 91e8de79855a3..e56a747e14190 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -44,6 +44,8 @@ def __init__( track_running_stats: bool = True, device=None, dtype=None, + *, + bias: bool = True, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -54,7 +56,10 @@ def __init__( self.track_running_stats = track_running_stats if self.affine: self.weight = Parameter(torch.empty(num_features, **factory_kwargs)) - self.bias = Parameter(torch.empty(num_features, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(num_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) else: self.register_parameter("weight", None) self.register_parameter("bias", None) @@ -95,7 +100,8 @@ def reset_parameters(self) -> None: self.reset_running_stats() if self.affine: init.ones_(self.weight) - init.zeros_(self.bias) + if self.bias is not None: + init.zeros_(self.bias) def _check_input_dim(self, input): raise NotImplementedError @@ -103,7 +109,9 @@ def _check_input_dim(self, input): def extra_repr(self): return ( "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " - "track_running_stats={track_running_stats}".format(**self.__dict__) + "bias={use_bias}, track_running_stats={track_running_stats}".format( + **self.__dict__, use_bias=self.bias is not None + ) ) def _load_from_state_dict( @@ -151,10 +159,18 @@ def __init__( track_running_stats: bool = True, device=None, dtype=None, + *, + bias: bool = True, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__( - num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + num_features, + eps, + momentum, + affine, + track_running_stats, + **factory_kwargs, + bias=bias, ) def forward(self, input: Tensor) -> Tensor: @@ -220,11 +236,13 @@ def __init__( track_running_stats=True, device=None, dtype=None, + *, + bias=True, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} # pyrefly: ignore [bad-argument-type] super().__init__( - # affine and track_running_stats are hardcoded to False to + # affine, bias and track_running_stats are hardcoded to False to # avoid creating tensors that will soon be overwritten. 0, eps, @@ -232,14 +250,16 @@ def __init__( False, False, **factory_kwargs, + bias=False, ) self.affine = affine self.track_running_stats = track_running_stats if self.affine: # pyrefly: ignore [unexpected-keyword] self.weight = UninitializedParameter(**factory_kwargs) - # pyrefly: ignore [unexpected-keyword] - self.bias = UninitializedParameter(**factory_kwargs) + if bias: + # pyrefly: ignore # bad-argument-type + self.bias = UninitializedParameter(**factory_kwargs) if self.track_running_stats: # pyrefly: ignore [unexpected-keyword] self.running_mean = UninitializedBuffer(**factory_kwargs) @@ -266,10 +286,13 @@ def initialize_parameters(self, input) -> None: # type: ignore[override] raise AssertionError( "self.weight must be an UninitializedParameter" ) - if not isinstance(self.bias, UninitializedParameter): - raise AssertionError("self.bias must be an UninitializedParameter") self.weight.materialize((self.num_features,)) - self.bias.materialize((self.num_features,)) + if self.bias is not None: + if not isinstance(self.bias, UninitializedParameter): + raise AssertionError( + "self.bias must be an UninitializedParameter" + ) + self.bias.materialize((self.num_features,)) if self.track_running_stats: self.running_mean.materialize( # type:ignore[union-attr] (self.num_features,) @@ -335,6 +358,8 @@ class BatchNorm1d(_BatchNorm): buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size, @@ -381,6 +406,8 @@ class LazyBatchNorm1d(_LazyNormBase, _BatchNorm): buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` """ cls_to_become = BatchNorm1d # type: ignore[assignment] @@ -447,6 +474,8 @@ class BatchNorm2d(_BatchNorm): buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, H, W)` @@ -492,6 +521,8 @@ class LazyBatchNorm2d(_LazyNormBase, _BatchNorm): buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` """ cls_to_become = BatchNorm2d # type: ignore[assignment] @@ -558,6 +589,8 @@ class BatchNorm3d(_BatchNorm): buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, D, H, W)` @@ -603,6 +636,8 @@ class LazyBatchNorm3d(_LazyNormBase, _BatchNorm): buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` """ cls_to_become = BatchNorm3d # type: ignore[assignment] @@ -677,6 +712,8 @@ class SyncBatchNorm(_BatchNorm): process_group: synchronization of stats happen within each process group individually. Default behavior is synchronization across the whole world + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, +)` @@ -725,10 +762,18 @@ def __init__( process_group: Any | None = None, device=None, dtype=None, + *, + bias: bool = True, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__( - num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + num_features, + eps, + momentum, + affine, + track_running_stats, + **factory_kwargs, + bias=bias, ) self.process_group = process_group @@ -886,6 +931,7 @@ def convert_sync_batchnorm(cls, module, process_group=None): module.affine, module.track_running_stats, process_group, + bias=module.bias is not None, ) if module.affine: with torch.no_grad(): diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py index 058ffb3ed9aa9..76a343afbd0dd 100644 --- a/torch/nn/modules/instancenorm.py +++ b/torch/nn/modules/instancenorm.py @@ -28,10 +28,18 @@ def __init__( track_running_stats: bool = False, device=None, dtype=None, + *, + bias: bool = True, # for backward compatibility ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__( - num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + num_features, + eps, + momentum, + affine, + track_running_stats, + **factory_kwargs, + bias=bias, ) def _check_input_dim(self, input): @@ -174,11 +182,13 @@ class InstanceNorm1d(_InstanceNorm): momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. - Default: ``False``. + Default: ``False`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, L)` or :math:`(C, L)` @@ -218,11 +228,13 @@ class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm): momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. - Default: ``False``. + Default: ``False`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, L)` or :math:`(C, L)` @@ -290,11 +302,13 @@ class InstanceNorm2d(_InstanceNorm): momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. - Default: ``False``. + Default: ``False`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)` @@ -335,11 +349,13 @@ class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm): momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. - Default: ``False``. + Default: ``False`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)` @@ -406,11 +422,13 @@ class InstanceNorm3d(_InstanceNorm): momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. - Default: ``False``. + Default: ``False`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` @@ -451,11 +469,13 @@ class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm): momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. - Default: ``False``. + Default: ``False`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index c32178af0b82e..cab31782edbde 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -142,9 +142,9 @@ class LayerNorm(Module): eps: a value added to the denominator for numerical stability. Default: 1e-5 elementwise_affine: a boolean value that when set to ``True``, this module has learnable per-element affine parameters initialized to ones (for weights) - and zeros (for biases). Default: ``True``. + and zeros (for biases). Default: ``True`` bias: If set to ``False``, the layer will not learn an additive bias (only relevant if - :attr:`elementwise_affine` is ``True``). Default: ``True``. + :attr:`elementwise_affine` is ``True``). Default: ``True`` Attributes: weight: the learnable weights of the module of shape @@ -231,8 +231,8 @@ def forward(self, input: Tensor) -> Tensor: def extra_repr(self) -> str: return ( - "{normalized_shape}, eps={eps}, " - "elementwise_affine={elementwise_affine}".format(**self.__dict__) + "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}, " + "bias={use_bias}".format(**self.__dict__, use_bias=self.bias is not None) ) @@ -263,7 +263,9 @@ class GroupNorm(Module): eps: a value added to the denominator for numerical stability. Default: 1e-5 affine: a boolean value that when set to ``True``, this module has learnable per-channel affine parameters initialized to ones (for weights) - and zeros (for biases). Default: ``True``. + and zeros (for biases). Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}` @@ -296,6 +298,8 @@ def __init__( affine: bool = True, device=None, dtype=None, + *, + bias: bool = True, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -310,7 +314,10 @@ def __init__( self.affine = affine if self.affine: self.weight = Parameter(torch.empty(num_channels, **factory_kwargs)) - self.bias = Parameter(torch.empty(num_channels, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(num_channels, **factory_kwargs)) + else: + self.register_parameter("bias", None) else: self.register_parameter("weight", None) self.register_parameter("bias", None) @@ -320,14 +327,16 @@ def __init__( def reset_parameters(self) -> None: if self.affine: init.ones_(self.weight) - init.zeros_(self.bias) + if self.bias is not None: + init.zeros_(self.bias) def forward(self, input: Tensor) -> Tensor: return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) def extra_repr(self) -> str: - return "{num_groups}, {num_channels}, eps={eps}, affine={affine}".format( - **self.__dict__ + return ( + "{num_groups}, {num_channels}, eps={eps}, affine={affine}, " + "bias={use_bias}".format(**self.__dict__, use_bias=self.bias is not None) ) diff --git a/torch/onnx/_internal/fx/type_utils.py b/torch/onnx/_internal/fx/type_utils.py index bd32799bc4e2b..ca434fa213f79 100644 --- a/torch/onnx/_internal/fx/type_utils.py +++ b/torch/onnx/_internal/fx/type_utils.py @@ -4,11 +4,14 @@ from __future__ import annotations from typing import Any +from typing_extensions import TypeIs import torch -def is_torch_symbolic_type(value: Any) -> bool: +def is_torch_symbolic_type( + value: Any, +) -> TypeIs[torch.SymBool | torch.SymInt | torch.SymFloat]: return isinstance(value, (torch.SymBool, torch.SymInt, torch.SymFloat)) diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py b/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py index d52e2a6ee9249..0e00a7a5cd263 100644 --- a/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py @@ -110,7 +110,7 @@ NoReturn, TypeVar as _TypeVar, ) -from typing_extensions import ParamSpec as _ParamSpec +from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs import torch import torch._C._onnx as _C_onnx @@ -561,7 +561,7 @@ def _is_none(x: Any) -> bool: return x is None or (x.node().mustBeNone() if isinstance(x, _C.Value) else False) -def _is_value(x: Any) -> bool: +def _is_value(x: Any) -> _TypeIs[_C.Value]: return isinstance(x, _C.Value) diff --git a/torch/storage.py b/torch/storage.py index 1ac246e7476da..9bb671c0bd510 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -552,7 +552,6 @@ def _new_dtypes(): torch.bits2x4, torch.bits4x2, torch.complex32, - torch.bcomplex32, torch.uint16, torch.uint32, torch.uint64, diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index ef60dd19f7c3e..2605411ed7c9d 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -18,37 +18,6 @@ HAS_NUMPY = False np = None # type: ignore[assignment] -_HAS_DTENSOR = torch.distributed.is_available() - - -def _unwrap_dtensor_for_comparison(actual, expected): - """Handle DTensor inputs for assertEqual/assert_close.""" - if not _HAS_DTENSOR: - return actual, expected - from torch.distributed.tensor import DTensor - - actual_dt = isinstance(actual, DTensor) - expected_dt = isinstance(expected, DTensor) - if actual_dt and expected_dt: - if actual.placements != expected.placements: - raise AssertionError( - f"DTensor placements do not match: " - f"{actual.placements} != {expected.placements}" - ) - if actual.device_mesh != expected.device_mesh: - raise AssertionError( - f"DTensor device meshes do not match: " - f"{actual.device_mesh} != {expected.device_mesh}" - ) - return actual.to_local(), expected.to_local() - elif actual_dt != expected_dt: - raise TypeError( - "Comparing a DTensor to a non-DTensor is ambiguous. " - "Call .full_tensor() to compare the full logical tensor " - "or .to_local() to compare the local shard." - ) - return actual, expected - class ErrorMeta(Exception): """Internal testing exception that makes that carries error metadata.""" @@ -1604,8 +1573,6 @@ def assert_close( # Hide this function from `pytest`'s traceback __tracebackhide__ = True - actual, expected = _unwrap_dtensor_for_comparison(actual, expected) - error_metas = not_close_error_metas( actual, expected, diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py index 2c8c723c53d0e..6b212d4a84b1d 100644 --- a/torch/testing/_creation.py +++ b/torch/testing/_creation.py @@ -28,7 +28,7 @@ torch.float8_e4m3fnuz, torch.float8_e5m2fnuz, ] -_COMPLEX_TYPES = [torch.complex32, torch.bcomplex32, torch.complex64, torch.complex128] +_COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128] _BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES] _FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES] diff --git a/torch/testing/_internal/common_dtype.py b/torch/testing/_internal/common_dtype.py index c0fd47c740dc1..c5f3ad3f390dc 100644 --- a/torch/testing/_internal/common_dtype.py +++ b/torch/testing/_internal/common_dtype.py @@ -174,7 +174,6 @@ def get_all_dtypes( include_complex=True, include_complex32=False, include_qint=False, - include_bcomplex32=False, ) -> list[torch.dtype]: dtypes = get_all_int_dtypes() + get_all_fp_dtypes( include_half=include_half, include_bfloat16=include_bfloat16 @@ -182,9 +181,7 @@ def get_all_dtypes( if include_bool: dtypes.append(torch.bool) if include_complex: - dtypes += get_all_complex_dtypes( - include_complex32=include_complex32, include_bcomplex32=include_bcomplex32 - ) + dtypes += get_all_complex_dtypes(include_complex32) if include_qint: dtypes += get_all_qint_dtypes() return dtypes @@ -200,15 +197,12 @@ def get_all_math_dtypes(device) -> list[torch.dtype]: ) -def get_all_complex_dtypes( - *, include_complex32=False, include_bcomplex32=False -) -> list[torch.dtype]: - dtypes = [torch.complex64, torch.complex128] - if include_bcomplex32: - dtypes.insert(0, torch.bcomplex32) - if include_complex32: - dtypes.insert(0, torch.complex32) - return dtypes +def get_all_complex_dtypes(include_complex32=False) -> list[torch.dtype]: + return ( + [torch.complex32, torch.complex64, torch.complex128] + if include_complex32 + else [torch.complex64, torch.complex128] + ) def get_all_int_dtypes() -> list[torch.dtype]: @@ -228,9 +222,15 @@ def get_all_qint_dtypes() -> list[torch.dtype]: return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4] +def highest_precision_float(device): + if torch.device(device).type == "mps": + return torch.float32 + else: + return torch.float64 + + float_to_corresponding_complex_type_map = { torch.float16: torch.complex32, - torch.bfloat16: torch.bcomplex32, torch.float32: torch.complex64, torch.float64: torch.complex128, } diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index ffeeea5cc0a9a..7f1d062cb46de 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -23,6 +23,7 @@ _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types, floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and, empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types, + highest_precision_float, ) from torch.testing._internal.common_device_type import ( onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, @@ -2531,7 +2532,7 @@ def reference_inputs_cat(op, device, dtype, requires_grad, **kwargs): # Noncontiguous type promoting tensors a = make_arg((3, 4, 2)) - b = make_arg((3, 2, 2), noncontiguous=True, dtype=torch.double) + b = make_arg((3, 2, 2), noncontiguous=True, dtype=highest_precision_float(device)) c = make_arg((3, 3, 2), dtype=torch.float16).permute(1, 0, 2) yield SampleInput((a, b, c), kwargs={'dim': 1}) @@ -2690,7 +2691,7 @@ def error_inputs_gather(op_info, device, **kwargs): # Creates new src & idx since SampleInputs can't share tensors src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) - out = torch.empty((2, 2), device=device, dtype=torch.float64) + out = torch.empty((2, 2), device=device, dtype=torch.float16 if torch.device(device).type == 'mps' else torch.float64) yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}), error_regex="Expected out tensor to have dtype") @@ -2753,7 +2754,7 @@ def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs): # Error when self.dtype != src.dtype (and src is not a scalar) src = make_tensor((2, 5), device=device, dtype=torch.float32) idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long) - dst = torch.zeros((3, 5), device=device, dtype=torch.double) + dst = torch.zeros((3, 5), device=device, dtype=torch.float16 if torch.device(device).type == 'mps' else torch.double) yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_regex="Expected self.dtype to be equal to src.dtype") @@ -2829,7 +2830,8 @@ def error_inputs_t(op_info, device, **kwargs): def error_inputs_multinomial(op_info, device, **kwargs): - x = torch.empty(1, 2, 3, dtype=torch.double, device=device) + dtype = highest_precision_float(device) + x = torch.empty(1, 2, 3, dtype=dtype, device=device) yield ErrorInput(SampleInput(x, args=(2,)), error_regex="prob_dist must be 1 or 2 dim") @@ -2837,24 +2839,24 @@ def error_inputs_multinomial(op_info, device, **kwargs): yield ErrorInput(SampleInput(x, args=(2,)), error_regex="multinomial only supports floating-point dtypes for input") - x = torch.empty(1, 2, dtype=torch.double, device=device) - y = torch.empty(1, 2, dtype=torch.double, device=device) + x = torch.empty(1, 2, dtype=dtype, device=device) + y = torch.empty(1, 2, dtype=dtype, device=device) yield ErrorInput(SampleInput(x, args=(2,), kwargs=dict(out=y)), error_regex="multinomial expects Long tensor out") - x = torch.empty(2, dtype=torch.double, device=device) + x = torch.empty(2, dtype=dtype, device=device) yield ErrorInput(SampleInput(x, args=(0,)), error_regex="cannot sample n_sample <= 0 samples") - x = torch.empty(2, dtype=torch.double, device=device) + x = torch.empty(2, dtype=dtype, device=device) yield ErrorInput(SampleInput(x, args=(-1,)), error_regex="cannot sample n_sample <= 0 samples") - x = torch.empty(2, dtype=torch.double, device=device) + x = torch.empty(2, dtype=dtype, device=device) yield ErrorInput(SampleInput(x, args=(3, False,)), error_regex="cannot sample n_sample > prob_dist") - x = torch.empty(16777217, dtype=torch.double, device=device) + x = torch.empty(16777217, dtype=dtype, device=device) yield ErrorInput(SampleInput(x, args=(3,)), error_regex="number of categories cannot exceed") @@ -9412,8 +9414,9 @@ def sample_inputs_l1_loss(op_info, device, dtype, requires_grad, **kwargs): # test COMPLEX_TO_FLOAT promotion if dtype.is_complex: make = partial(make_tensor, (), device=device, requires_grad=requires_grad) - yield SampleInput(make(dtype=dtype), args=(make(dtype=torch.double),)) - yield SampleInput(make(dtype=torch.double), args=(make(dtype=dtype),)) + other_dtype = highest_precision_float(device) + yield SampleInput(make(dtype=dtype), args=(make(dtype=other_dtype),)) + yield SampleInput(make(dtype=other_dtype), args=(make(dtype=dtype),)) def error_inputs_l1_loss(op_info, device, **kwargs): make = partial(make_tensor, device=device, dtype=torch.float32) @@ -17694,13 +17697,7 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k OpInfo( "to", op=lambda x, *args, **kwargs: x.to(*args, **kwargs), - dtypes=all_types_and_complex_and( - torch.bfloat16, - torch.float16, - torch.bool, - torch.complex32, - torch.bcomplex32, - ), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_out=False, @@ -19589,8 +19586,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k skips=( # RuntimeError: gather(): Yet not supported for complex DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='mps'), ), ), OpInfo('index_fill', @@ -19799,8 +19794,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # RuntimeError: scatter(): Yet not supported for complex DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='mps'), )), UnaryUfuncInfo( 'bfloat16', @@ -20494,8 +20487,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k error_inputs_func=error_inputs_multinomial, skips=( DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='mps'), # Strides are not the same! # This may not be reproducible in CI DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), @@ -20619,8 +20610,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # RuntimeError: scatter(): Yet not supported for complex DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='mps'), )), OpInfo('stack', dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), @@ -20752,8 +20741,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k check_batched_forward_grad=False, assert_autodiffed=True, skips=( - # https://github.com/pytorch/pytorch/issues/89353 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'), # RuntimeError: Arguments for call not valid. # Expected a value of type 'List[Tensor]' for argument # 'tensors' but instead found type 'Tensor (inferred)'. @@ -20763,8 +20750,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # see https://github.com/pytorch/pytorch/issues/99806 # RuntimeError: The size of tensor a (25) must match the size of tensor b (0) at non-singleton dimension 0. DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', device_type='mps', dtypes=(torch.int64,)), )), OpInfo('unbind', dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), @@ -21503,9 +21488,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k "test_variant_consistency_jit", dtypes=(torch.float32,), ), - # Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), ), ), UnaryUfuncInfo('lgamma', @@ -23017,11 +22999,9 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k "test_variant_consistency_jit", dtypes=(torch.float32, torch.complex64), ), - # Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 # RuntimeError: norm ops are not supported for complex yet DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', device_type='mps', dtypes=(torch.int64,)), ), ), OpInfo( @@ -23346,16 +23326,17 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k "_refs.lerp", torch_opinfo_name="lerp", skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype + # Exception: Dtypes torch.float32 and * are not equal! DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.bfloat16, torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8) + dtypes=(torch.bool, torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8) ), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type='mps', dtypes=(torch.bool,)), DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.bfloat16, torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8) + dtypes=(torch.bool, torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8) ), + # RuntimeError: Failed to create function state object for: abs_dense_bool_bool + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type='mps', dtypes=(torch.bool,)), ), ), PythonRefInfo( @@ -23587,15 +23568,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.bfloat16, torch.float16,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.bfloat16, torch.float16,) - ), # RuntimeError: value cannot be converted to type uint8_t without overflow DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', @@ -23679,12 +23651,13 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), device_type="cuda"), - # Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='mps'), + # RuntimeError: no _refs support for aten.copy.default DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps'), + # AssertionError: Tensor-likes are not equal! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', + dtypes=(torch.bool, torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8) + ), ), ), PythonRefInfo( @@ -24159,33 +24132,13 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k "_refs.special.multigammaln", torch_opinfo_name="mvlgamma", torch_opinfo_variant_name="mvlgamma_p_3", - skips=skips_mvlgamma() + ( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.bfloat16, torch.float16) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.bfloat16, torch.float16) - ), - ), + skips=skips_mvlgamma(), ), ElementwiseUnaryPythonRefInfo( "_refs.special.multigammaln", torch_opinfo_name="mvlgamma", torch_opinfo_variant_name="mvlgamma_p_5", - skips=skips_mvlgamma() + ( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.bfloat16, torch.float16) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.bfloat16, torch.float16) - ), - ), + skips=skips_mvlgamma(), ), ElementwiseUnaryPythonRefInfo( "_refs.log", @@ -24234,14 +24187,15 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_name="log_softmax", torch_opinfo_variant_name="with_dtype", skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 + # AssertionError: Tensor-likes are not close! + # RuntimeError: softmax only supported for floating types DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64, torch.complex32) + dtypes=(torch.float32, torch.complex64, torch.complex32) ), DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64, torch.complex32) + dtypes=(torch.float32, torch.complex64, torch.complex32) ), ), ), @@ -24388,15 +24342,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=[torch.cfloat]), - # TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype. - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.complex64,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.complex64,) - ), ), ), ElementwiseUnaryPythonRefInfo( @@ -24426,15 +24371,15 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_name="softmax", torch_opinfo_variant_name="with_dtype", skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - # TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype. + # AssertionError: Tensor-likes are not close! + # RuntimeError: softmax only supported for floating types DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64), + dtypes=(torch.float32, torch.complex64), ), DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64), + dtypes=(torch.float32, torch.complex64), ), ), ), @@ -24525,10 +24470,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_variant_name="with_dtype", supports_out=False, skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 + # AssertionError: Tensor-likes are not close! + # NotImplementedError: log_softmax for complex is not supported for MPS DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64, torch.complex32) + dtypes=(torch.float32, torch.complex64, torch.complex32) ), ), ), @@ -24538,11 +24484,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_variant_name="with_dtype", supports_out=False, skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - # TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype. + # Exception: softmax only supported for floating types + # AssertionError: Tensor-likes are not close! DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64), + dtypes=(torch.float32, torch.complex64), ), ), ), @@ -24552,17 +24498,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k ElementwiseUnaryPythonRefInfo( "_refs.special.logit", torch_opinfo_name="logit", - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.complex64, torch.float16) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.complex64, torch.float16) - ), - ), ), # # Elementwise Unary nn.functional OpInfos @@ -24702,7 +24637,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_name="nn.functional.pairwise_distance", supports_out=True, skips=( - # Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 # RuntimeError: norm ops are not supported for complex yet DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), @@ -24739,10 +24673,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_variant_name="with_dtype", supports_out=False, skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 + # NotImplementedError: log_softmax for complex is not supported for MPS + # AssertionError: Tensor-likes are not close! DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64, torch.complex32) + dtypes=(torch.float32, torch.complex64, torch.complex32) ), ), ), @@ -24757,17 +24692,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k PythonRefInfo( "_refs.nn.functional.poisson_nll_loss", torch_opinfo_name="nn.functional.poisson_nll_loss", - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.float16, torch.bfloat16) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.float16, torch.bfloat16) - ), - ), ), ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.prelu", @@ -24842,11 +24766,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_variant_name="with_dtype", supports_out=False, skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - # TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype. + # RuntimeError: softmax only supported for floating types + # AssertionError: Tensor-likes are not close! DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64), + dtypes=(torch.float32, torch.complex64), ), ), ), @@ -24856,11 +24780,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_variant_name="with_dtype", supports_out=False, skips=( - # TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype. - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 + # RuntimeError: softmax only supported for floating types + # AssertionError: Tensor-likes are not close! DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.complex64, torch.float16, torch.float32), + dtypes=(torch.complex64, torch.float32), ), ), ), @@ -24888,11 +24812,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k PythonRefInfo( "_refs.nn.functional.l1_loss", torch_opinfo_name="nn.functional.l1_loss", - skips=( - # Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - ), ), PythonRefInfo( "_refs.nn.functional.margin_ranking_loss", @@ -24909,13 +24828,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k PythonRefInfo( "_refs.nn.functional.hinge_embedding_loss", torch_opinfo_name="nn.functional.hinge_embedding_loss", - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.bfloat16, torch.float16,), - ), - ), ), PythonRefInfo( "_refs.nn.functional.nll_loss", @@ -24996,15 +24908,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k dtypes=(torch.complex64, torch.complex128), device_type='cpu', active_if=(IS_MACOS or IS_WINDOWS)), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.float16, torch.bfloat16) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.float16, torch.bfloat16) - ), ), ), ElementwiseUnaryPythonRefInfo( @@ -25045,15 +24948,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', dtypes=(torch.complex64, torch.complex128)), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.bfloat16,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.float16,) - ), ), ), ElementwiseBinaryPythonRefInfo( @@ -25181,15 +25075,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k decorators=( # See https://github.com/pytorch/pytorch/issues/111126 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.bfloat16, torch.float16) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.bfloat16, torch.float16) - ), ), ), ElementwiseBinaryPythonRefInfo( @@ -25395,17 +25280,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k ElementwiseUnaryPythonRefInfo( "_refs.logical_not", torch_opinfo_name="logical_not", - skips=( - # RuntimeError: Undefined type ComplexDouble - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.complex64,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.complex64,) - ), - ), ), ElementwiseBinaryPythonRefInfo( "_refs.logical_or", @@ -25600,16 +25474,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'TestBinaryUfuncs', 'test_reference_numerics_small_values', dtypes=(torch.uint8,)), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - # NotImplementedError: "_local_scalar_dense_mps" not implemented for 'ComplexHalf' - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.float16,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.float16, torch.bfloat16) - ), ), ), ElementwiseBinaryPythonRefInfo( @@ -25645,17 +25509,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k PythonRefInfo( "_refs.addcdiv", torch_opinfo_name="addcdiv", - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.bfloat16, torch.float16,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.bfloat16, torch.float16,) - ), - ), ), PythonRefInfo( "_refs.addcmul", @@ -25669,18 +25522,17 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k dtypes=(torch.float16,), device_type="cpu"), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', dtypes=(torch.float16,), device_type="cpu"), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 # AssertionError: Tensor-likes are not close! DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', dtypes=( torch.uint8, torch.int8, torch.int64, torch.int32, - torch.int16, torch.float16, torch.bfloat16, + torch.int16, ) ), DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', dtypes=( torch.uint8, torch.int8, torch.int64, torch.int32, - torch.int16, torch.float16, torch.bfloat16, + torch.int16, ) ), ), @@ -26004,10 +25856,15 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k skips=( # FIXME: AssertionError: RuntimeError not raised DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps'), + # RuntimeError: Failed to create function state object for: cat_int32_t_* + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', + dtypes=(torch.complex64, torch.complex32) + ), + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', + dtypes=(torch.complex64, torch.complex32) + ), ), ), PythonRefInfo( @@ -26193,12 +26050,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # RuntimeError: norm ops are not supported for complex yet DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - # Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', dtypes=(torch.float16,)), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.float16,) - ), ), ), PythonRefInfo( @@ -26687,14 +26538,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k decorators=( DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',), # RuntimeError: MPS device does not support addr for non-float input - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS - # framework doesn't support float64. DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', dtypes=( torch.uint8, torch.int8, torch.int64, torch.int32, - torch.int16, torch.float16, torch.complex64, torch.bool, - torch.bfloat16 + torch.int16, torch.complex64, torch.bool, ) ), ), @@ -26711,11 +26559,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # https://github.com/pytorch/pytorch/issues/77216 validate_view_consistency=False, skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.float16,) - ), # RuntimeError: norm ops are not supported for complex yet DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), DecorateInfo( diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 10807aa111e33..57303de39c4e7 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -913,7 +913,10 @@ def module_inputs_torch_nn_BatchNorm1d(module_info, device, dtype, requires_grad desc='3d_input_not_affine'), ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False), forward_input=FunctionInput(make_input((0, 5, 9))), - desc='zero_batch')] + desc='zero_batch'), + ModuleInput(constructor_input=FunctionInput(10, bias=False), + forward_input=FunctionInput(make_input((4, 10))), + desc='affine_not_bias'),] def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, training, **kwargs): @@ -936,7 +939,10 @@ def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad desc='not_tracking_stats'), ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False), forward_input=FunctionInput(make_input((0, 5, 2, 2))), - desc='zero_batch')] + desc='zero_batch'), + ModuleInput(constructor_input=FunctionInput(3, bias=False), + forward_input=FunctionInput(make_input((2, 3, 6, 6))), + desc='affine_not_bias'),] def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, training, **kwargs): @@ -959,7 +965,10 @@ def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad desc='not_tracking_stats'), ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False), forward_input=FunctionInput(make_input((0, 5, 2, 2, 2))), - desc='zero_batch')] + desc='zero_batch'), + ModuleInput(constructor_input=FunctionInput(3, bias=False), + forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))), + desc='affine_not_bias'),] def module_error_inputs_torch_nn_BatchNorm1d_2d_3d(module_info, device, dtype, requires_grad, training, **kwargs): @@ -1837,6 +1846,10 @@ def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, constructor_input=FunctionInput(3, 6, 1e-3), forward_input=FunctionInput(make_input((4, 6, 5))), desc='1d_affine'), + ModuleInput( + constructor_input=FunctionInput(3, 6, 1e-3, bias=False), + forward_input=FunctionInput(make_input((4, 6, 5))), + desc='1d_affine_not_bias'), ModuleInput( constructor_input=FunctionInput(3, 12, 1e-3), forward_input=FunctionInput(make_input((4, 12))), @@ -1857,6 +1870,10 @@ def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, constructor_input=FunctionInput(3, 6, 1e-3), forward_input=FunctionInput(make_input((4, 6, 2, 3))), desc='2d_affine'), + ModuleInput( + constructor_input=FunctionInput(3, 9, 1e-3, bias=False), + forward_input=FunctionInput(make_input((4, 9, 2, 3))), + desc='2d_affine_not_bias'), ModuleInput( constructor_input=FunctionInput(3, 3, 1e-3, False), forward_input=FunctionInput(make_input((4, 3, 2, 3))), @@ -1864,8 +1881,7 @@ def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, ModuleInput( constructor_input=FunctionInput(1, 3, 1e-3, False), forward_input=FunctionInput(make_input((4, 3, 2, 3))), - desc='2d_no_affine_LN'), - ] + desc='2d_no_affine_LN'),] def module_error_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, training, **kwargs): @@ -2054,8 +2070,21 @@ def module_inputs_torch_nn_InstanceNormNd(module_info, device, dtype, requires_g ), forward_input=FunctionInput(make_input(input_no_batch_shape)), reference_fn=no_batch_dim_reference_fn, - desc='no_batch_dim') - ] + desc='no_batch_dim'), + ModuleInput( + constructor_input=( + FunctionInput(eps, momentum, affine=True) if lazy else + FunctionInput(num_features, eps, momentum, affine=True) + ), + forward_input=FunctionInput(make_input(input_batch_shape)), + desc='affine'), + ModuleInput( + constructor_input=( + FunctionInput(eps, momentum, affine=True, bias=False) if lazy else + FunctionInput(num_features, eps, momentum, affine=True, bias=False) + ), + forward_input=FunctionInput(make_input(input_batch_shape)), + desc='affine_not_bias'),] def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad, training, **kwargs): make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index ee8ef700f2cbf..a4cad3f8492b6 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -122,6 +122,7 @@ def mps_ops_modifier( "nn.functional.conv_transpose2d", "nn.functional.conv_transpose3d", "nn.functional.feature_alpha_dropoutwithout_train", + "nn.functional.l1_loss", "nn.functional.padcircular", "nn.functional.softminwith_dtype", "nn.functional.softsign", @@ -1008,11 +1009,6 @@ def mps_ops_error_inputs_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: "clamp_max", "clamp_min", "masked_scatter", - # unsupported float64 dtype - "multinomial", - "gather", - "scatter", - "scatter_add", # MPS does not support tensor dimensions > 16 "amax", "amin", diff --git a/torch/testing/_internal/common_ops_unbacked.py b/torch/testing/_internal/common_ops_unbacked.py index 5dcf058a05314..56a47a263cfc4 100644 --- a/torch/testing/_internal/common_ops_unbacked.py +++ b/torch/testing/_internal/common_ops_unbacked.py @@ -153,6 +153,7 @@ def skip(op_name, variant_name="", *, device_type=None, dtypes=None): xfail("nn.functional.fractional_max_pool2d"), xfail("nn.functional.fractional_max_pool3d"), xfail("nn.functional.gaussian_nll_loss"), + xfail("nn.functional.glu"), xfail("nn.functional.grid_sample"), xfail("nn.functional.group_norm"), xfail("nn.functional.huber_loss"), diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 193f8085c703b..8b8507433c7d8 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -82,14 +82,13 @@ ) from torch.testing import make_tensor from torch.testing._comparison import ( - _unwrap_dtensor_for_comparison, BooleanPair, NonePair, - not_close_error_metas, NumberPair, Pair, TensorLikePair, ) +from torch.testing._comparison import not_close_error_metas from torch.testing._internal.common_dtype import get_all_dtypes from torch.utils._import_utils import _check_module_exists import torch.utils._pytree as pytree @@ -4316,8 +4315,6 @@ def to_list(input): if isinstance(y, torch.Tensor) and y.is_nested and y.layout == torch.strided: y = y.unbind() - x, y = _unwrap_dtensor_for_comparison(x, y) - error_metas = not_close_error_metas( x, y, diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py index b6064e2fca1ad..26fb73497f978 100644 --- a/torch/testing/_internal/opinfo/definitions/linalg.py +++ b/torch/testing/_internal/opinfo/definitions/linalg.py @@ -2800,20 +2800,6 @@ def make_input(): DecorateInfo( unittest.expectedFailure, "TestCommon", "test_python_ref_errors" ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.bfloat16, torch.float16), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.bfloat16, torch.float16), - ), ), ), PythonRefInfo( @@ -2826,23 +2812,6 @@ def make_input(): "_refs.linalg.vecdot", torch_opinfo_name="linalg.vecdot", op_db=op_db, - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - ), ), ReductionPythonRefInfo( "_refs.linalg.vector_norm", @@ -2850,21 +2819,6 @@ def make_input(): supports_out=True, op_db=op_db, skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.float16,), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.float16,), - ), # Exception: norm ops are not supported for complex yet DecorateInfo( unittest.expectedFailure, diff --git a/torch/testing/_internal/opinfo/definitions/special.py b/torch/testing/_internal/opinfo/definitions/special.py index d375ba21358ff..1626ead643244 100644 --- a/torch/testing/_internal/opinfo/definitions/special.py +++ b/torch/testing/_internal/opinfo/definitions/special.py @@ -878,23 +878,6 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): }, ), ), - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - ), ), ElementwiseUnaryPythonRefInfo( "_refs.special.bessel_j1", @@ -908,23 +891,6 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): }, ), ), - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - ), ), ElementwiseUnaryPythonRefInfo( "_refs.special.entr", @@ -995,21 +961,6 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): "test_reference_numerics_large", dtypes=(torch.int8,), ), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.float16,), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.float16,), - ), ), ), ElementwiseUnaryPythonRefInfo( @@ -1041,23 +992,6 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): "_refs.special.ndtr", torch_opinfo_name="special.ndtr", op_db=op_db, - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - ), ), ElementwiseUnaryPythonRefInfo( "_refs.special.ndtri", diff --git a/torch/utils/_dtype_abbrs.py b/torch/utils/_dtype_abbrs.py index 0b6210be113ca..c4eb9c56671db 100644 --- a/torch/utils/_dtype_abbrs.py +++ b/torch/utils/_dtype_abbrs.py @@ -16,7 +16,6 @@ torch.complex32: "c32", torch.complex64: "c64", torch.complex128: "c128", - torch.bcomplex32: "bc32", torch.int8: "i8", torch.int16: "i16", torch.int32: "i32", diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index 8d42d14968de7..963d0b5a20065 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -372,6 +372,9 @@ def _print_Where(self, expr: sympy.Expr) -> str: ) return f"{c} ? {p} : {q}" + def _print_Or(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " || ", precedence(expr)) + def _print_Piecewise(self, expr: sympy.Expr) -> str: # Convert Piecewise(expr_cond_pairs) to nested ternary operators # Piecewise((e1, c1), (e2, c2), ..., (eN, cN))